ai-edge-torch-nightly 0.3.0.dev20241116__py3-none-any.whl → 0.3.0.dev20241119__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → paligemma/convert_to_tflite.py} +23 -25
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -1
- ai_edge_torch/generative/examples/paligemma/paligemma.py +25 -9
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +50 -1
- ai_edge_torch/generative/test/utils.py +6 -3
- ai_edge_torch/generative/utilities/converter.py +71 -87
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/RECORD +14 -14
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+_PREFILL_SEQ_LEN = flags.DEFINE_multi_integer(
     'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
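The Gemma2 conversion flag now accepts several prefill lengths via absl's DEFINE_multi_integer. A minimal standalone sketch of how such a flag parses (flag name and default taken from the diff; the argv list is hypothetical):

from absl import flags

FLAGS = flags.FLAGS
_PREFILL_SEQ_LEN = flags.DEFINE_multi_integer(
    'prefill_seq_len',
    (8, 64, 128, 256, 512, 1024),
    'List of the maximum sizes of prefill input tensors.',
)

# Repeating the flag on the command line accumulates values into a list;
# with no occurrences, the default tuple above is used.
FLAGS(['convert_gemma2_to_tflite',
      '--prefill_seq_len=256', '--prefill_seq_len=512'])
print(_PREFILL_SEQ_LEN.value)  # [256, 512]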
ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → paligemma/convert_to_tflite.py} RENAMED

@@ -13,20 +13,24 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of converting a Gemma2 model to multi-signature tflite model, with multiple prefill lengths."""
+"""Example of converting a PaliGemma model to multi-signature tflite model.
+
+DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
+https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
+"""

-import logging
 import os
 import pathlib

 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
+import torch

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -34,16 +38,21 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
+    1024,
+    'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
     1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
+_PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
+    'pixel_values_size',
+    [3, 224, 224],
+    'The size of prefill pixel values except the batch dimension.',
+)
 _QUANTIZE = flags.DEFINE_bool(
     'quantize',
     True,
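The multi-integer flag parses to a plain Python list like [3, 224, 224]; main() below wraps it in a torch.Size, and the converter prepends the batch dimension when building the traced pixel input. A quick standalone check of that shape arithmetic:

import torch

size = torch.Size([3, 224, 224])  # flag default, minus the batch dimension
sample = torch.full((1,) + size, 0, dtype=torch.float32)
assert sample.shape == torch.Size([1, 3, 224, 224])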
@@ -51,32 +60,21 @@ _QUANTIZE = flags.DEFINE_bool(
 )


-
-# …
-# with multiple prefill signatures for different prefill lengths for faster
-# inference.
-def convert_to_tflite_multi_prefill_lens():
-  pytorch_model = gemma2.build_2b_model(
+def main(_):
+  pytorch_model = paligemma.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'…
-  converter.convert_to_tflite_multi_prefill_lens(
+  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_lens=_PREFILL_SEQ_LENS.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
   )


-def main(_):
-  if len(_PREFILL_SEQ_LENS.value) > 1:
-    # If multiple prefill lengths are provided, export a model with multiple
-    # prefill signatures each for a different prefill length.
-    convert_to_tflite_multi_prefill_lens()
-  else:
-    logging.warning('Need more than one prefill lengths to be specified.')
-
-
 if __name__ == '__main__':
   app.run(main)
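With the default flag values (quantize=True, prefill_seq_len=1024, kv_cache_max_len=1280), the f-string above resolves to the exported file name; a plain Python check, not part of the diff:

quant_suffix = 'q8'  # _QUANTIZE defaults to True
output_filename = f'paligemma_{quant_suffix}_seq{1024}_ekv{1280}.tflite'
assert output_filename == 'paligemma_q8_seq1024_ekv1280.tflite'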
ai_edge_torch/generative/examples/paligemma/image_encoder.py CHANGED

@@ -59,7 +59,7 @@ class SiglipVisionEncoder(nn.Module):
         out_channels=config.embedding_dim,
         kernel_size=config.image_embedding.patch_size,
         stride=config.image_embedding.patch_size,
-        padding=…
+        padding=0,
     )
     num_patches = (
         config.image_embedding.image_size // config.image_embedding.patch_size
@@ -144,6 +144,8 @@ def get_fake_image_encoder_config() -> cfg.ModelConfig:
   config = get_image_encoder_config()
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
+  config.image_embedding.image_size = 8
+  config.image_embedding.patch_size = 2
   config.num_layers = 2
   return config

ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED

@@ -54,6 +54,10 @@ class PaliGemma(nn.Module):
         bias=config.image_projection_use_bias,
     )
     self.decoder = decoder.Decoder(config.decoder_config)
+    image_embedding_config = config.image_encoder_config.image_embedding
+    self.num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
     self.config = config

   @torch.inference_mode
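With the fake test config above (image_size=8, patch_size=2), this works out to 16 patches, and the new test_paligemma test later sizes its token tensor as num_patches + 10. A quick check of that arithmetic:

image_size, patch_size = 8, 2  # values from get_fake_image_encoder_config()
num_patches = (image_size // patch_size) ** 2
assert num_patches == 16
tokens_len = num_patches + 10  # mirrors the test_paligemma setup below
assert tokens_len == 26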
@@ -74,10 +78,22 @@ class PaliGemma(nn.Module):
     if self.config.decoder_config.embedding_scale is not None:
       image_embeds = image_embeds / self.config.decoder_config.embedding_scale

-    # Merge image_embeds into text_embeds as PaliGemmaForConditionalGeneration.
-    image_mask = tokens == self.config.image_token_id
-    image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
-    input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+    # can be done like:
+    #
+    #   image_mask = tokens == self.config.image_token_id
+    #   image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    #   input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Since PaliGemma token embedder reserves the first [num_patches] tokens
+    # for image tokens, we can use this property to merge image_embeds into
+    # input_embeds by concatenating them.
+    assert image_embeds.shape[1] == self.num_patches
+    assert input_embeds.shape[1] >= self.num_patches
+    input_embeds = torch.cat(
+        (image_embeds, input_embeds[:, self.num_patches:, :]), dim=1
+    )

     return self.decoder(
         tokens=None,
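The comment above claims the concatenation is equivalent to masked_scatter whenever the image tokens occupy the first num_patches positions. A toy check of that equivalence (not repo code; shapes and the stand-in image token id 99, in place of the real 257152, are made up):

import torch

num_patches, seq_len, dim = 4, 7, 3
tokens = torch.tensor([[99, 99, 99, 99, 1, 2, 3]])
input_embeds = torch.randn(1, seq_len, dim)
image_embeds = torch.randn(1, num_patches, dim)

# masked_scatter formulation quoted in the comment
image_mask = (tokens == 99).unsqueeze(-1).expand_as(input_embeds)
scattered = input_embeds.masked_scatter(image_mask, image_embeds)

# concat formulation used by the new code
concatenated = torch.cat(
    (image_embeds, input_embeds[:, num_patches:, :]), dim=1
)
assert torch.equal(scattered, concatenated)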
@@ -87,7 +103,7 @@ class PaliGemma(nn.Module):
     )


-def get_model_config() -> PaliGemmaConfig:
+def get_model_config(**kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.

   Returns:
@@ -95,13 +111,13 @@ def get_model_config() -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
       image_projection_use_bias=True,
       image_token_id=257152,
   )


-def get_fake_image_encoder_config() -> PaliGemmaConfig:
+def get_fake_model_config() -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=decoder.get_fake_decoder_config(),
@@ -110,8 +126,8 @@ def get_fake_image_encoder_config() -> PaliGemmaConfig:
   )


-def build_model(checkpoint_path: str) -> PaliGemma:
-  config = get_model_config()
+def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
+  config = get_model_config(**kwargs)
   model = PaliGemma(config)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
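The **kwargs plumbing lets callers thread decoder options from build_model down through get_model_config to decoder.get_decoder_config; the new convert_to_tflite.py uses it exactly this way. A sketch of the call (checkpoint path illustrative):

# kv_cache_max_len is forwarded build_model -> get_model_config ->
# decoder.get_decoder_config; the image encoder config is untouched.
pytorch_model = paligemma.build_model(
    '/path/to/paligemma-3b-224',  # illustrative checkpoint location
    kv_cache_max_len=1280,
)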
ai_edge_torch/generative/test/test_model_conversion.py CHANGED

@@ -117,7 +117,7 @@ class TestModelConversion(googletest.TestCase):
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
+    prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
     prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED

@@ -22,6 +22,7 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.qwen import qwen
@@ -55,7 +56,7 @@ class TestModelConversion(googletest.TestCase):

   def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
+    tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
@@ -171,6 +172,54 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_paligemma(self):
+    config = paligemma.get_fake_model_config()
+    pytorch_model = paligemma.PaliGemma(config).eval()
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    image_embedding_config = config.image_encoder_config.image_embedding
+    num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
+    # Make sure the token size is longer than the number of image patches.
+    tokens_len = num_patches + 10
+    tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
+
+    edge_model = ai_edge_torch.signature(
+        "prefill_pixel",
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+            "pixel_values": pixel_values,
+        },
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            tokens,
+            input_pos,
+            kv,
+            pixel_values=pixel_values,
+            signature_name="prefill_pixel",
+            atol=1e-3,
+            rtol=1e-5,
+        )
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/test/utils.py CHANGED

@@ -32,18 +32,21 @@ def compare_tflite_torch(
     signature_name: str,
     atol: float = 1e-5,
     rtol: float = 1e-5,
+    **kwargs,
 ):
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
-  torch_output = torch_model(tokens, input_pos, kv_cache)
+  torch_output = torch_model(tokens, input_pos, kv_cache, **kwargs)

-  …
+  if "pixel_values" in kwargs:
+    kwargs["pixel_values"] = kwargs["pixel_values"].numpy()
+  kwargs.update({k: v.numpy() for k, v in zip(flat_names, values)})
   edge_output = edge_model(
       signature_name=signature_name,
       tokens=tokens.numpy(),
       input_pos=input_pos.numpy(),
-      **…
+      **kwargs,
   )

   return np.allclose(
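With the new **kwargs path, extra model inputs such as pixel_values ride along to the torch model and, after .numpy() conversion, to the edge model. The call shape the new test_paligemma test uses (arguments as built there):

test_utils.compare_tflite_torch(
    edge_model,
    pytorch_model,
    tokens,
    input_pos,
    kv,
    pixel_values=pixel_values,      # extra kwarg forwarded to both models
    signature_name="prefill_pixel",
    atol=1e-3,
    rtol=1e-5,
)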
ai_edge_torch/generative/utilities/converter.py CHANGED

@@ -15,9 +15,11 @@

 """Common utility functions for model conversion."""

-import ai_edge_torch
+from typing import Union
+
 from ai_edge_torch._convert import converter as converter_utils
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -25,109 +27,74 @@ import torch
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
     tflite_path: str,
-    prefill_seq_len: int = 512,
+    prefill_seq_len: Union[int, list[int]],
+    pixel_values_size: torch.Size = None,
     quantize: bool = True,
+    config: cfg.ModelConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.

-  A PyTorch model will be converted to a tflite model with two signatures:
-  "prefill" and "decode".
-
-  "prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
-  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
-  external KV cache as a sample input.
-
-  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
-  of shape [1, 1] of the token position, and an external KV cache as a sample
-  input.
+  A PyTorch model will be converted to a tflite model with several signatures:
+    * "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
+      passed),
+    * "prefill_[preill_seq_len]_pixel" (or "prefill_pixel" if only one
+      prefill_seq_len is passed) if num_pixel_values > 0, and
+    * "decode".
+
+  "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
+  passed) signature takes as a sample input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions, and
+    * an external KV cache.
+
+  If num_pixel_values > 0, "prefill_[prefill_seq_len]_pixel" (or "prefill_pixel"
+  if only one prefill_seq_len is passed) signature takes as a sample input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions,
+    * an external KV cache, and
+    * a tensor of shape [1, num_pixel_values] of pixel values.
+
+  "decode" signature takes as a sample input:
+    * a tensor of shape [1, 1] of token sequence,
+    * a tensor of shape [1, 1] of the token position, and
+    * an external KV cache.

   The final tflite model will be exported to tflite_path.

   Args:
     pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
     tflite_path (str): The tflite file path to export.
-    prefill_seq_len (int, optional): The maximum size of prefill input
-      tensor. Defaults to 512.
+    prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to
+      export.
+    pixel_values_size (torch.Size, optional): The size of pixel values to pass
+      to the model. If None, the model is not expected to take pixel values.
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
+    config (cfg.ModelConfig, optional): The model config used to configure KV
+      cache. If None, it uses the config of the pytorch_model.
   """
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
+  prefill_seq_lens = (
+      [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
   )
-  edge_model.export(tflite_path)
-
-
-def convert_to_tflite_multi_prefill_lens(
-    pytorch_model: torch.nn.Module,
-    tflite_path: str,
-    prefill_seq_lens: list[int],
-    quantize: bool = True,
-):
-  """Converts a nn.Module model to multi-signature tflite model with different
-
-  prefill lengths.
-
-  A PyTorch model will be converted to a tflite model with several signatures:
-  "prefill_[prefill_seq_len]" and "decode".
-
-  "prefill_[prefill_seq_len]" signature takes a tensor of shape [1,
-  prefill_seq_len] of token
-  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
-  external KV cache as a sample input.
-
-  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
-  of shape [1, 1] of the token position, and an external KV cache as a sample
-  input.
-
-  The final tflite model will be exported to tflite_path.

-  Args:
-    pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-    tflite_path (str): The tflite file path to export.
-    prefill_seq_lens (list[int]): A list of prefill lengths to export.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
   # Tensors used to trace the model graph during conversion.
   prefill_tokens_list = []
   prefill_input_pos_list = []
-  for i in range(len(prefill_seq_lens)):
-    prefill_tokens_list.append(
-        torch.full((1, prefill_seq_lens[i]), 0, dtype=torch.int)
-    )
-    prefill_input_pos_list.append(
-        torch.arange(0, prefill_seq_lens[i], dtype=torch.int)
-    )
+  for seq_len in prefill_seq_lens:
+    prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
+    prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
+
+  prefill_pixel_values = (
+      torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
+      if pixel_values_size
+      else None
+  )

   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+  kv = kv_utils.KVCache.from_model_config(
+      config if config else pytorch_model.config
+  )

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   converter = converter_utils.Converter()
@@ -135,8 +102,12 @@ def convert_to_tflite_multi_prefill_lens(
     prefill_seq_len = prefill_seq_lens[i]
     prefill_tokens = prefill_tokens_list[i]
     prefill_input_pos = prefill_input_pos_list[i]
+    if i == 0 and len(prefill_seq_lens) == 1:
+      prefill_signature_name = 'prefill'
+    else:
+      prefill_signature_name = f'prefill_{prefill_seq_len}'
     converter.add_signature(
-        f'prefill_{prefill_seq_len}',
+        prefill_signature_name,
         pytorch_model,
         sample_kwargs={
             'tokens': prefill_tokens,
@@ -144,8 +115,19 @@
             'kv_cache': kv,
         },
     )
+    if prefill_pixel_values is not None:
+      converter.add_signature(
+          prefill_signature_name + '_pixel',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+              'pixel_values': prefill_pixel_values,
+          },
+      )

-  edge_model = converter.add_signature(
+  converter.add_signature(
       'decode',
       pytorch_model,
       sample_kwargs={
@@ -153,5 +135,7 @@
           'input_pos': decode_input_pos,
           'kv_cache': kv,
       },
-  ).convert(quant_config=quant_config)
+  )
+
+  edge_model = converter.convert(quant_config=quant_config)
   edge_model.export(tflite_path)
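A compact restatement of the signature-naming rule the loop above implements (a sketch, not repo code): a single prefill length keeps the legacy names, several lengths get suffixed ones, and each prefill signature gains a '_pixel' twin when pixel values are traced.

def signature_names(prefill_seq_lens, with_pixel):
  names = []
  for i, seq_len in enumerate(prefill_seq_lens):
    if i == 0 and len(prefill_seq_lens) == 1:
      name = 'prefill'
    else:
      name = f'prefill_{seq_len}'
    names.append(name)
    if with_pixel:
      names.append(name + '_pixel')
  names.append('decode')
  return names

assert signature_names([1024], True) == ['prefill', 'prefill_pixel', 'decode']
assert signature_names([256, 512], False) == [
    'prefill_256', 'prefill_512', 'decode'
]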
ai_edge_torch/version.py CHANGED

-__version__ = "0.3.0.dev20241116"
+__version__ = "0.3.0.dev20241119"

{ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241116
+Version: 0.3.0.dev20241119
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/RECORD CHANGED

@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=…
+ai_edge_torch/version.py,sha256=W10ztMY91LZfiK-COm46eLLfufHNyUl2W0DtuM8zeC4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -45,8 +45,7 @@ ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n7
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
-ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=…
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=Mss7j3yDhyOFCWA93iWh995CLeNBDTVG-gvpj6WBIp0,2226
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
@@ -61,9 +60,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=dT7dnx1dzGzFiH5gQJ4M6zcTLSRFvSDpi3IuZ9_vd78,2706
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=…
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=…
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
@@ -135,12 +135,12 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=…
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=…
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
-ai_edge_torch/generative/test/utils.py,sha256=…
+ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=…
+ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -193,8 +193,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/METADATA,sha256=…
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/METADATA,sha256=ikosXUollu7saRd3GUi1IZw78jvgCShWAhtDF3NuUtE,1897
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/RECORD,,
The remaining dist-info files (LICENSE, WHEEL, top_level.txt) were renamed with no content changes.