ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
- ai_edge_torch/generative/examples/llama/llama.py +19 -24
- ai_edge_torch/generative/examples/llama/verify.py +18 -3
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
- ai_edge_torch/generative/examples/phi/phi2.py +10 -86
- ai_edge_torch/generative/examples/phi/phi3.py +9 -69
- ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
- ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
- ai_edge_torch/generative/layers/model_config.py +6 -0
- ai_edge_torch/generative/test/test_loader.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
- ai_edge_torch/generative/utilities/model_builder.py +141 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
- ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py DELETED
@@ -1,68 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Example of converting Llama 3.2 3B model to multi-signature tflite model."""
-
-import os
-import pathlib
-
-from absl import app
-from absl import flags
-from ai_edge_torch.generative.examples.llama import llama
-from ai_edge_torch.generative.utilities import converter
-
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
-    '/tmp/',
-    'The tflite file path to export.',
-)
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-
-
-def main(_):
-  pytorch_model = llama.build_3b_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
-  )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_3b_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  converter.convert_to_tflite(
-      pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
-      quantize=_QUANTIZE.value,
-  )
-
-
-if __name__ == '__main__':
-  app.run(main)
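For reference, the deleted script's conversion flow can still be reproduced programmatically. The sketch below condenses the removed convert_3b_to_tflite.py into a direct call sequence; it assumes llama.build_3b_model and converter.convert_to_tflite remain exported with the signatures shown above, which may no longer hold after this release's refactor into model_builder.py. The checkpoint path is a placeholder.

import os
import pathlib

from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import converter

# Placeholder location; any directory holding the Llama 3.2 3B checkpoint
# works (this was the deleted script's flag default).
checkpoint_path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama')

# NOTE: build_3b_model may have been renamed or folded into the shared
# model_builder utility in this release.
pytorch_model = llama.build_3b_model(checkpoint_path, kv_cache_max_len=1280)

# Export a quantized multi-signature TFLite model with the same defaults the
# deleted script used (prefill_seq_len=1024, kv_cache_max_len=1280, q8).
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/llama_3b_q8_seq1024_ekv1280.tflite',
    prefill_seq_len=1024,
    quantize=True,
)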
ai_edge_torch/generative/examples/llama/verify_3b.py DELETED
@@ -1,73 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Verifies the reauthored Llama 3.2-3B model."""
-
-import logging
-import pathlib
-
-from absl import app
-from absl import flags
-from ai_edge_torch.generative.examples.llama import llama
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
-
-
-_PROMPTS = flags.DEFINE_multi_string(
-    "prompts",
-    "What is the meaning of life?",
-    "The input prompts to generate answers.",
-)
-_MAX_NEW_TOKENS = flags.DEFINE_integer(
-    "max_new_tokens",
-    30,
-    "The maximum size of the generated tokens.",
-)
-
-
-def main(_):
-  checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = llama.build_3b_model(reauthored_checkpoint)
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
-  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
-  # available.
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
-      max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
-  )
-
-
-if __name__ == "__main__":
-  app.run(main)
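Likewise, the removed verification flow reduces to a single verifier call. Below is a minimal sketch condensed from the deleted verify_3b.py above, with the same caveat that build_3b_model may have been superseded in this release:

import pathlib

import transformers
from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier

checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

# build_3b_model expects a local checkpoint directory; like the deleted
# script, resolve it from the HuggingFace cache.
cached_config = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
reauthored_model = llama.build_3b_model(pathlib.Path(cached_config).parent)

# Compare outputs of the original and reauthored models on the same prompts.
verifier.verify_reauthored_model(
    original_model=transformers_verifier.TransformersModelWrapper(original_model),
    reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
    tokenizer=verifier.TokenizerWrapper(tokenizer),
    generate_prompts=["What is the meaning of life?"],
    max_new_tokens=30,
    atol=1e-04,
)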