ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl

Files changed (24)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  12. ai_edge_torch/generative/layers/model_config.py +6 -0
  13. ai_edge_torch/generative/test/test_loader.py +2 -1
  14. ai_edge_torch/generative/test/test_model_conversion.py +2 -1
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  16. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
  20. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  21. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
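
The pattern in this list — big deletions across the per-model examples (tiny_llama.py −101, gemma1.py −93, openelm.py −90, phi2.py −86, ...) paired with a brand-new ai_edge_torch/generative/utilities/model_builder.py (+141) — suggests the duplicated model-definition code was consolidated behind one shared builder, and that the two deleted 3B-specific Llama scripts (items 20 and 21, shown in full below) were folded into the generalized convert_to_tflite.py and verify.py. A toy, runnable sketch of that consolidation pattern follows; every name in it is a hypothetical stand-in, not the actual model_builder.py API.

import torch
from torch import nn


class TinyDecoder(nn.Module):
  """Toy decoder (embedding + projection), standing in for a real transformer stack."""

  def __init__(self, vocab_size: int, dim: int):
    super().__init__()
    self.embed = nn.Embedding(vocab_size, dim)
    self.proj = nn.Linear(dim, vocab_size)

  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    return self.proj(self.embed(tokens))


def build_model(vocab_size: int, dim: int) -> nn.Module:
  # Hypothetical shared builder: each per-model example would pass only its
  # own hyperparameters here instead of defining a full model class.
  return TinyDecoder(vocab_size, dim)


if __name__ == "__main__":
  logits = build_model(vocab_size=32, dim=8)(torch.zeros(1, 4, dtype=torch.long))
  print(logits.shape)  # torch.Size([1, 4, 32])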
ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py (deleted)
@@ -1,68 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Example of converting Llama 3.2 3B model to multi-signature tflite model."""
-
- import os
- import pathlib
-
- from absl import app
- from absl import flags
- from ai_edge_torch.generative.examples.llama import llama
- from ai_edge_torch.generative.utilities import converter
-
- _CHECKPOINT_PATH = flags.DEFINE_string(
-     'checkpoint_path',
-     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
-     'The path to the model checkpoint, or directory holding the checkpoint.',
- )
- _TFLITE_PATH = flags.DEFINE_string(
-     'tflite_path',
-     '/tmp/',
-     'The tflite file path to export.',
- )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
-     'prefill_seq_len',
-     1024,
-     'The maximum size of prefill input tensor.',
- )
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-     'kv_cache_max_len',
-     1280,
-     'The maximum size of KV cache buffer, including both prefill and decode.',
- )
- _QUANTIZE = flags.DEFINE_bool(
-     'quantize',
-     True,
-     'Whether the model should be quantized.',
- )
-
-
- def main(_):
-   pytorch_model = llama.build_3b_model(
-       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
-   )
-   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-   output_filename = f'llama_3b_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-   converter.convert_to_tflite(
-       pytorch_model,
-       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-       prefill_seq_len=_PREFILL_SEQ_LEN.value,
-       quantize=_QUANTIZE.value,
-   )
-
-
- if __name__ == '__main__':
-   app.run(main)
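
For reference, the removed wrapper reduced to two library calls. A minimal sketch of the same flow as plain Python, using the APIs exactly as the script above did (the checkpoint path is a placeholder, and llama.build_3b_model is itself gone as of this release):

from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import converter

# Defaults mirror the flags above: quantize=True -> 'q8', prefill 1024, KV 1280.
pytorch_model = llama.build_3b_model('/path/to/llama', kv_cache_max_len=1280)
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/llama_3b_q8_seq1024_ekv1280.tflite',
    prefill_seq_len=1024,
    quantize=True,
)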
ai_edge_torch/generative/examples/llama/verify_3b.py (deleted)
@@ -1,73 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Verifies the reauthored Llama 3.2-3B model."""
-
- import logging
- import pathlib
-
- from absl import app
- from absl import flags
- from ai_edge_torch.generative.examples.llama import llama
- from ai_edge_torch.generative.utilities import transformers_verifier
- from ai_edge_torch.generative.utilities import verifier
- import transformers
-
-
- _PROMPTS = flags.DEFINE_multi_string(
-     "prompts",
-     "What is the meaning of life?",
-     "The input prompts to generate answers.",
- )
- _MAX_NEW_TOKENS = flags.DEFINE_integer(
-     "max_new_tokens",
-     30,
-     "The maximum size of the generated tokens.",
- )
-
-
- def main(_):
-   checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
-   logging.info("Loading the original model from: %s", checkpoint)
-   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-   # Locate the cached dir.
-   cached_config_file = transformers.utils.cached_file(
-       checkpoint, transformers.utils.CONFIG_NAME
-   )
-   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-   reauthored_model = llama.build_3b_model(reauthored_checkpoint)
-
-   logging.info("Loading the tokenizer from: %s", checkpoint)
-   # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
-   # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
-   # available.
-   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-   verifier.verify_reauthored_model(
-       original_model=transformers_verifier.TransformersModelWrapper(
-           original_model
-       ),
-       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-       tokenizer=verifier.TokenizerWrapper(tokenizer),
-       generate_prompts=_PROMPTS.value,
-       max_new_tokens=_MAX_NEW_TOKENS.value,
-       atol=1e-04,
-   )
-
-
- if __name__ == "__main__":
-   app.run(main)
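
One idiom in the removed script is worth keeping: rather than downloading the checkpoint twice, it asks transformers where the model's config file landed in the local Hugging Face cache and reuses that snapshot directory as the weight source for the reauthored model. Shown standalone below; cached_file and CONFIG_NAME are real transformers utilities, and the checkpoint name follows the script above.

import pathlib
import transformers

checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
# Fetching (or cache-hitting) the config reveals the snapshot directory
# that also holds the weight files.
config_file = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
snapshot_dir = pathlib.Path(config_file).parent
print(snapshot_dir)  # typically .../hub/models--meta-llama--.../snapshots/<hash>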