ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (34)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
  3. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
  4. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  5. ai_edge_torch/generative/examples/openelm/verify.py +61 -0
  6. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
  7. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  8. ai_edge_torch/generative/examples/phi/verify.py +53 -0
  9. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
  10. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  11. ai_edge_torch/generative/examples/smollm/verify.py +59 -0
  12. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
  14. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  15. ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
  16. ai_edge_torch/generative/layers/attention.py +8 -4
  17. ai_edge_torch/generative/layers/builder.py +3 -1
  18. ai_edge_torch/generative/layers/model_config.py +3 -0
  19. ai_edge_torch/generative/layers/normalization.py +31 -20
  20. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +11 -4
  22. ai_edge_torch/generative/layers/unet/model_config.py +3 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
  24. ai_edge_torch/generative/utilities/converter.py +82 -0
  25. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +10 -0
  26. ai_edge_torch/generative/utilities/verifier.py +200 -0
  27. ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
  28. ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
  29. ai_edge_torch/version.py +1 -1
  30. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
  31. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +34 -28
  32. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
  33. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
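
Taken together, files 1–8 below show one refactor applied uniformly across the generative examples: each per-model `convert_*_to_tflite()` entry point, with its hand-rolled graph-tracing and export code, is replaced by absl flags plus a call into the new shared `ai_edge_torch/generative/utilities/converter.py`, and each in-module `define_and_run()` golden-logits check is replaced by a standalone `verify.py` backed by the new `ai_edge_torch/generative/utilities/verifier.py`. Assuming absl's standard flag handling, an illustrative invocation of a refactored script would be `python convert_gemma2_to_tflite.py --checkpoint_path=/path/to/gemma2-2b --noquantize`; a sketch of the shared converter follows the first hunk below.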
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma2 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma2.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma2_to_tflite(path)
+  app.run(main)
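
The shared helper that replaces the inline conversion code lives in `ai_edge_torch/generative/utilities/converter.py` (+82 lines), which this diff does not display. As a minimal sketch, reconstructed from the code deleted above and the new call site (the shipped implementation may differ, e.g. in how it handles the export path):

# Hypothetical reconstruction of utilities/converter.py, assembled from the
# inline code this release deletes; not the actual shipped implementation.
import ai_edge_torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes
import torch


def convert_to_tflite(
    pytorch_model: torch.nn.Module,
    tflite_path: str,
    prefill_seq_len: int = 512,
    quantize: bool = True,
):
  """Exports a model as a tflite file with 'prefill' and 'decode' signatures."""
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
  decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill',
          pytorch_model,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
              'kv_cache': kv,
          },
      )
      .signature(
          'decode',
          pytorch_model,
          sample_kwargs={
              'tokens': decode_token,
              'input_pos': decode_input_pos,
              'kv_cache': kv,
          },
      )
      .convert(quant_config=quant_config)
  )
  edge_model.export(tflite_path)

The same refactor repeats verbatim for Gemma, OpenELM, and Phi-2 in the hunks that follow.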
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_openelm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts OpenELM model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = openelm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
-  convert_openelm_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,16 +15,12 @@
 
 """Example of building an OpenELM model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/openelm/verify.py (new file)
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
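
This `verify.py` replaces the `define_and_run()` golden check removed from `openelm.py` above. The `verifier` module itself (`ai_edge_torch/generative/utilities/verifier.py`, +200 lines) is not part of this diff; from its call sites, `verify_reauthored_model()` takes the original transformers model, the reauthored model, a tokenizer, a list of prompts, and an optional `atol`. A rough sketch of the comparison it plausibly performs, modeled on the deleted golden test; everything beyond the call-site signature is an assumption:

# Hypothetical sketch of verifier.verify_reauthored_model(); the shipped
# implementation is not shown in this diff and likely does more (e.g. logging
# and token-by-token generation).
import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils


def verify_reauthored_model(
    original_model, reauthored_model, tokenizer, prompts, atol=1e-05
):
  for prompt in prompts:
    tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]
    # Logits from the original transformers model.
    original_logits = original_model(tokens).logits

    # Logits from the reauthored model, whose forward() takes explicit
    # position and KV-cache arguments (cf. the removed define_and_run()).
    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
    kv = kv_utils.KVCache.from_model_config(reauthored_model.config)
    reauthored_logits = reauthored_model.forward(tokens.int(), input_pos, kv)[
        "logits"
    ]

    # Compare the logits of the last prompt token within tolerance.
    assert torch.allclose(
        original_logits[0, -1, :], reauthored_logits[0, -1, :], atol=atol
    )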
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.layers import kv_cache
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_phi2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Phi-2 model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = phi2.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,16 +15,12 @@
 
 """Example of building a Phi-2 model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)
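
The `use_input_shape=False` knob added above is defined in `model_config.py` (+3) and consumed in `normalization.py` (+31/-20), neither of which is shown in this diff. Going only by the inline comment ("Phi-2 does layer-norm with the weight shape"), the flag plausibly chooses which shape `F.layer_norm` normalizes over; a hypothetical illustration, not the shipped code:

# Hypothetical reading of use_input_shape; the actual normalization.py change
# is not included in this diff.
import torch
import torch.nn.functional as F


def layer_norm(x, weight, bias, use_input_shape: bool, eps: float = 1e-5):
  if use_input_shape:
    # Normalize over the full input shape, broadcasting the affine params.
    return F.layer_norm(
        x,
        x.shape,
        weight.broadcast_to(x.shape),
        bias.broadcast_to(x.shape),
        eps,
    )
  # Phi-2 style: normalize over the weight's own (trailing) shape.
  return F.layer_norm(x, weight.shape, weight, bias, eps)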
ai_edge_torch/generative/examples/phi/verify.py (new file)
@@ -0,0 +1,53 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
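
Both verify scripts follow the same pattern. Since `flags.DEFINE_multi_string` accumulates repeated occurrences, several prompts can be checked in one run, e.g. `python verify.py --prompts="What is the meaning of life?" --prompts="Write a haiku."` (illustrative invocation). Unlike the OpenELM script, the Phi-2 one passes an explicit `atol=1e-03`, echoing the loose `atol=1e-02` tolerance of the golden test it replaces.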