ai-edge-torch-nightly 0.3.0.dev20241023__py3-none-any.whl → 0.3.0.dev20241027__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +82 -0
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py +82 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +11 -1
- ai_edge_torch/generative/utilities/converter.py +75 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241023.dist-info → ai_edge_torch_nightly-0.3.0.dev20241027.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241023.dist-info → ai_edge_torch_nightly-0.3.0.dev20241027.dist-info}/RECORD +13 -8
- {ai_edge_torch_nightly-0.3.0.dev20241023.dist-info → ai_edge_torch_nightly-0.3.0.dev20241027.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241023.dist-info → ai_edge_torch_nightly-0.3.0.dev20241027.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241023.dist-info → ai_edge_torch_nightly-0.3.0.dev20241027.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/amd_llama_135m/__init__.py ADDED
```diff
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
```
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py ADDED
```diff
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building AMD-Llama-135m."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+
+TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for an AMD-Llama-135m model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+      Default is 1024.
+
+  Returns:
+    The model config for an AMD-Llama-135m model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=12,
+      head_dim=64,
+      num_query_groups=12,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=2048,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32000,
+      num_layers=12,
+      max_seq_len=2048,
+      embedding_dim=768,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
```
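Because `get_fake_model_config` shrinks the vocabulary and layer count, the new builder can be smoke-tested without downloading the real checkpoint. A minimal sketch; the `tokens`/`input_pos`/`kv_cache` keyword names mirror the converter utility later in this diff, so treat the exact forward call as an assumption:

```python
import torch

from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.utilities import model_builder

# Tiny stand-in config: 128-token vocab, 2 layers, no checkpoint required.
config = amd_llama_135m.get_fake_model_config()
model = model_builder.DecoderOnlyModel(config).eval()

# One prefill step over 8 placeholder tokens with a freshly initialized cache.
tokens = torch.zeros((1, 8), dtype=torch.int)
input_pos = torch.arange(0, 8, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(config)
output = model(tokens=tokens, input_pos=input_pos, kv_cache=kv)
```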
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py ADDED
```diff
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting AMD-Llama-135m model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/amd-llama-135m'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = amd_llama_135m.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'amd-llama-135m_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
```
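The same conversion can also be driven from Python instead of the absl flags. A short sketch, assuming a checkpoint has already been downloaded to the placeholder path:

```python
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.utilities import converter

# '/path/to/amd-llama-135m' is a placeholder for a local checkpoint directory.
pytorch_model = amd_llama_135m.build_model(
    '/path/to/amd-llama-135m', kv_cache_max_len=1280
)
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/amd-llama-135m_q8_seq1024_ekv1280.tflite',
    prefill_seq_len=1024,
    quantize=True,
)
```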
ai_edge_torch/generative/examples/amd_llama_135m/verify.py ADDED
```diff
@@ -0,0 +1,72 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored AMD-Llama-135M model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Tell me a story?\nOnce upon a time",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "amd/AMD-Llama-135m"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
```
ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py ADDED
```diff
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Gemma2 model to a multi-signature tflite model with multiple prefill lengths."""
+
+import logging
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_len',
+    (8, 64, 128, 256, 512, 1024),
+    'A list of prefill lengths to export.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+# Note that the converted model is not compatible with the LLM Inference
+# engine for now. The main purpose of this function is to let you export a
+# tflite model with multiple prefill signatures, one per prefill length, for
+# faster inference.
+def convert_to_tflite_multi_prefill_lens():
+  pytorch_model = gemma2.build_2b_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma2_{quant_suffix}_multi-prefill-seq_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite_multi_prefill_lens(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_lens=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+def main(_):
+  if len(_PREFILL_SEQ_LENS.value) > 1:
+    # If multiple prefill lengths are provided, export a model with multiple
+    # prefill signatures, each for a different prefill length.
+    convert_to_tflite_multi_prefill_lens()
+  else:
+    logging.warning('Need more than one prefill length to be specified.')
+
+
+if __name__ == '__main__':
+  app.run(main)
```
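This script is a thin wrapper over the new `converter.convert_to_tflite_multi_prefill_lens` utility added further down in this diff. Calling it directly looks roughly like this (a sketch; the checkpoint path is a placeholder, and `kv_cache_max_len` must cover the longest prefill plus the decoded tokens):

```python
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.utilities import converter

# '/path/to/gemma2-2b' is a placeholder for a local Gemma2 checkpoint.
pytorch_model = gemma2.build_2b_model(
    '/path/to/gemma2-2b', kv_cache_max_len=1280
)
converter.convert_to_tflite_multi_prefill_lens(
    pytorch_model,
    tflite_path='/tmp/gemma2_q8_multi-prefill-seq_ekv1280.tflite',
    prefill_seq_lens=[8, 64, 128, 256, 512, 1024],
    quantize=True,
)
```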
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
```diff
@@ -17,6 +17,7 @@
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
@@ -29,8 +30,8 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
-from ai_edge_torch.generative.utilities import model_builder
 from ai_edge_torch.generative.test import utils as test_utils
+from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -161,6 +162,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_amd_llama_135m(self):
+    config = amd_llama_135m.get_fake_model_config()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
```
ai_edge_torch/generative/utilities/converter.py CHANGED
```diff
@@ -16,6 +16,7 @@
 """Common utility functions for model conversion."""
 
 import ai_edge_torch
+from ai_edge_torch._convert import converter as converter_utils
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
@@ -80,3 +81,77 @@ def convert_to_tflite(
       .convert(quant_config=quant_config)
   )
   edge_model.export(tflite_path)
+
+
+def convert_to_tflite_multi_prefill_lens(
+    pytorch_model: torch.nn.Module,
+    tflite_path: str,
+    prefill_seq_lens: list[int],
+    quantize: bool = True,
+):
+  """Converts a nn.Module model to a multi-signature tflite model.
+
+  A PyTorch model will be converted to a tflite model with several signatures:
+  "prefill_[prefill_seq_len]" (one per entry in prefill_seq_lens) and "decode".
+
+  Each "prefill_[prefill_seq_len]" signature takes a tensor of shape
+  [1, prefill_seq_len] of token sequence, a tensor of shape
+  [1, prefill_seq_len] of token positions, and an external KV cache as a
+  sample input.
+
+  The "decode" signature takes a tensor of shape [1, 1] of token sequence, a
+  tensor of shape [1, 1] of the token position, and an external KV cache as a
+  sample input.
+
+  The final tflite model will be exported to tflite_path.
+
+  Args:
+    pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
+    tflite_path (str): The tflite file path to export.
+    prefill_seq_lens (list[int]): A list of prefill lengths to export.
+    quantize (bool, optional): Whether the model should be quantized. Defaults
+      to True.
+  """
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens_list = []
+  prefill_input_pos_list = []
+  for prefill_seq_len in prefill_seq_lens:
+    prefill_tokens_list.append(
+        torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+    )
+    prefill_input_pos_list.append(
+        torch.arange(0, prefill_seq_len, dtype=torch.int)
+    )
+
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  converter = converter_utils.Converter()
+  for i in range(len(prefill_seq_lens)):
+    prefill_seq_len = prefill_seq_lens[i]
+    prefill_tokens = prefill_tokens_list[i]
+    prefill_input_pos = prefill_input_pos_list[i]
+    converter.add_signature(
+        f'prefill_{prefill_seq_len}',
+        pytorch_model,
+        sample_kwargs={
+            'tokens': prefill_tokens,
+            'input_pos': prefill_input_pos,
+            'kv_cache': kv,
+        },
+    )
+
+  edge_model = converter.add_signature(
+      'decode',
+      pytorch_model,
+      sample_kwargs={
+          'tokens': decode_token,
+          'input_pos': decode_input_pos,
+          'kv_cache': kv,
+      },
+  ).convert(quant_config=quant_config)
+  edge_model.export(tflite_path)
```
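On the consumer side, each exported prefill length becomes its own TFLite signature. A minimal inspection sketch using the standard TensorFlow Lite Python API (assumptions: TensorFlow is installed, the model was exported with the defaults above, and the exact input names of each runner depend on how the KV cache is flattened, so list the signatures before invoking one):

```python
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='/tmp/gemma2_q8_multi-prefill-seq_ekv1280.tflite'
)
# Expect one entry per exported length ('prefill_8', ..., 'prefill_1024') plus
# 'decode'; each entry also reports its input and output tensor names.
print(interpreter.get_signature_list())

# Pick the smallest exported length that fits the prompt, padding up to it.
prefill_runner = interpreter.get_signature_runner('prefill_64')
```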
ai_edge_torch/version.py CHANGED
The one-line change bumps the version string from 0.3.0.dev20241023 to 0.3.0.dev20241027.

ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/METADATA CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241023
+Version: 0.3.0.dev20241027
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
```
ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/RECORD CHANGED
```diff
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=SrYveglaiA_DXPoRBqSXClWM1q7853I5ujRorq_MV0M,4251
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=VekzumwXByceYkTQ97jSNSKfX2vYBmx4ZSsHs9cyT-0,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -39,8 +39,13 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa
 ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=bkq2ZknJfuY7WC8wLVg92Z6eA_aMDbkgwaMxvmDW4_0,2618
+ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n79r6yFnCACpms5eMkXNpyQsCn2PYVRdB-jOoIqn14,2227
+ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
+ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=bZKOiAJBWPzIVHdASEgKRUFdyZSPVGFfe3uXUYrRh1c,2868
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
@@ -124,11 +129,11 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TzBEbWOoB7bIHePuP6ySL9eYfmKHpONgTQCU-f05m8c,9497
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
 ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -181,8 +186,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/METADATA,sha256=WYTOBwCoMZ3Z8G223xG54Lj8PTR9HUW2Yr5dUVtF0nA,1897
+ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/RECORD,,
```
The dist-info LICENSE, WHEEL, and top_level.txt files are unchanged; only their directory is renamed for the new version.