ai-edge-torch-nightly 0.3.0.dev20241116__py3-none-any.whl → 0.3.0.dev20241119__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → paligemma/convert_to_tflite.py} +23 -25
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -1
- ai_edge_torch/generative/examples/paligemma/paligemma.py +25 -9
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +50 -1
- ai_edge_torch/generative/test/utils.py +6 -3
- ai_edge_torch/generative/utilities/converter.py +71 -87
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/RECORD +14 -14
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241116.dist-info → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info}/top_level.txt +0 -0
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
|
|
33
33
|
'/tmp/',
|
34
34
|
'The tflite file path to export.',
|
35
35
|
)
|
36
|
-
_PREFILL_SEQ_LEN = flags.
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_multi_integer(
|
37
37
|
'prefill_seq_len',
|
38
|
-
1024,
|
39
|
-
'
|
38
|
+
(8, 64, 128, 256, 512, 1024),
|
39
|
+
'List of the maximum sizes of prefill input tensors.',
|
40
40
|
)
|
41
41
|
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
42
|
'kv_cache_max_len',
|
@@ -13,20 +13,24 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
"""Example of converting a
|
16
|
+
"""Example of converting a PaliGemma model to multi-signature tflite model.
|
17
|
+
|
18
|
+
DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
|
19
|
+
https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
|
20
|
+
"""
|
17
21
|
|
18
|
-
import logging
|
19
22
|
import os
|
20
23
|
import pathlib
|
21
24
|
|
22
25
|
from absl import app
|
23
26
|
from absl import flags
|
24
|
-
from ai_edge_torch.generative.examples.
|
27
|
+
from ai_edge_torch.generative.examples.paligemma import paligemma
|
25
28
|
from ai_edge_torch.generative.utilities import converter
|
29
|
+
import torch
|
26
30
|
|
27
31
|
_CHECKPOINT_PATH = flags.DEFINE_string(
|
28
32
|
'checkpoint_path',
|
29
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
|
33
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
|
30
34
|
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
35
|
)
|
32
36
|
_TFLITE_PATH = flags.DEFINE_string(
|
@@ -34,16 +38,21 @@ _TFLITE_PATH = flags.DEFINE_string(
|
|
34
38
|
'/tmp/',
|
35
39
|
'The tflite file path to export.',
|
36
40
|
)
|
37
|
-
|
41
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
38
42
|
'prefill_seq_len',
|
39
|
-
|
40
|
-
'
|
43
|
+
1024,
|
44
|
+
'The maximum size of prefill input tensor.',
|
41
45
|
)
|
42
46
|
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
43
47
|
'kv_cache_max_len',
|
44
48
|
1280,
|
45
49
|
'The maximum size of KV cache buffer, including both prefill and decode.',
|
46
50
|
)
|
51
|
+
_PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
|
52
|
+
'pixel_values_size',
|
53
|
+
[3, 224, 224],
|
54
|
+
'The size of prefill pixel values except the batch dimension.',
|
55
|
+
)
|
47
56
|
_QUANTIZE = flags.DEFINE_bool(
|
48
57
|
'quantize',
|
49
58
|
True,
|
@@ -51,32 +60,21 @@ _QUANTIZE = flags.DEFINE_bool(
|
|
51
60
|
)
|
52
61
|
|
53
62
|
|
54
|
-
|
55
|
-
|
56
|
-
# with multiple prefill signatures for different prefill lengths for faster
|
57
|
-
# inference.
|
58
|
-
def convert_to_tflite_multi_prefill_lens():
|
59
|
-
pytorch_model = gemma2.build_2b_model(
|
63
|
+
def main(_):
|
64
|
+
pytorch_model = paligemma.build_model(
|
60
65
|
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
61
66
|
)
|
62
67
|
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
63
|
-
output_filename = f'
|
64
|
-
converter.
|
68
|
+
output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
69
|
+
converter.convert_to_tflite(
|
65
70
|
pytorch_model,
|
66
71
|
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
|
67
|
-
|
72
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
73
|
+
pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
|
68
74
|
quantize=_QUANTIZE.value,
|
75
|
+
config=pytorch_model.config.decoder_config,
|
69
76
|
)
|
70
77
|
|
71
78
|
|
72
|
-
def main(_):
|
73
|
-
if len(_PREFILL_SEQ_LENS.value) > 1:
|
74
|
-
# If multiple prefill lengths are provided, export a model with multiple
|
75
|
-
# prefill signatures each for a different prefill length.
|
76
|
-
convert_to_tflite_multi_prefill_lens()
|
77
|
-
else:
|
78
|
-
logging.warning('Need more than one prefill lengths to be specified.')
|
79
|
-
|
80
|
-
|
81
79
|
if __name__ == '__main__':
|
82
80
|
app.run(main)
|
@@ -59,7 +59,7 @@ class SiglipVisionEncoder(nn.Module):
|
|
59
59
|
out_channels=config.embedding_dim,
|
60
60
|
kernel_size=config.image_embedding.patch_size,
|
61
61
|
stride=config.image_embedding.patch_size,
|
62
|
-
padding=
|
62
|
+
padding=0,
|
63
63
|
)
|
64
64
|
num_patches = (
|
65
65
|
config.image_embedding.image_size // config.image_embedding.patch_size
|
@@ -144,6 +144,8 @@ def get_fake_image_encoder_config() -> cfg.ModelConfig:
|
|
144
144
|
config = get_image_encoder_config()
|
145
145
|
# PaliGemma image encoder has only one block config.
|
146
146
|
config.block_config(0).ff_config.intermediate_size = 128
|
147
|
+
config.image_embedding.image_size = 8
|
148
|
+
config.image_embedding.patch_size = 2
|
147
149
|
config.num_layers = 2
|
148
150
|
return config
|
149
151
|
|
@@ -54,6 +54,10 @@ class PaliGemma(nn.Module):
|
|
54
54
|
bias=config.image_projection_use_bias,
|
55
55
|
)
|
56
56
|
self.decoder = decoder.Decoder(config.decoder_config)
|
57
|
+
image_embedding_config = config.image_encoder_config.image_embedding
|
58
|
+
self.num_patches = (
|
59
|
+
image_embedding_config.image_size // image_embedding_config.patch_size
|
60
|
+
) ** 2
|
57
61
|
self.config = config
|
58
62
|
|
59
63
|
@torch.inference_mode
|
@@ -74,10 +78,22 @@ class PaliGemma(nn.Module):
|
|
74
78
|
if self.config.decoder_config.embedding_scale is not None:
|
75
79
|
image_embeds = image_embeds / self.config.decoder_config.embedding_scale
|
76
80
|
|
77
|
-
#
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
+
# Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
|
82
|
+
# can be done like:
|
83
|
+
#
|
84
|
+
# image_mask = tokens == self.config.image_token_id
|
85
|
+
# image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
|
86
|
+
# input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
|
87
|
+
#
|
88
|
+
# Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
|
89
|
+
# Since PaliGemma token embedder reserves the first [num_patches] tokens
|
90
|
+
# for image tokens, we can use this property to merge image_embeds into
|
91
|
+
# input_embeds by concatenating them.
|
92
|
+
assert image_embeds.shape[1] == self.num_patches
|
93
|
+
assert input_embeds.shape[1] >= self.num_patches
|
94
|
+
input_embeds = torch.cat(
|
95
|
+
(image_embeds, input_embeds[:, self.num_patches:, :]), dim=1
|
96
|
+
)
|
81
97
|
|
82
98
|
return self.decoder(
|
83
99
|
tokens=None,
|
@@ -87,7 +103,7 @@ class PaliGemma(nn.Module):
|
|
87
103
|
)
|
88
104
|
|
89
105
|
|
90
|
-
def get_model_config() -> PaliGemmaConfig:
|
106
|
+
def get_model_config(**kwargs) -> PaliGemmaConfig:
|
91
107
|
"""Returns the model config for a PaliGemma 3B-224 model.
|
92
108
|
|
93
109
|
Returns:
|
@@ -95,13 +111,13 @@ def get_model_config() -> PaliGemmaConfig:
|
|
95
111
|
"""
|
96
112
|
return PaliGemmaConfig(
|
97
113
|
image_encoder_config=image_encoder.get_image_encoder_config(),
|
98
|
-
decoder_config=decoder.get_decoder_config(),
|
114
|
+
decoder_config=decoder.get_decoder_config(**kwargs),
|
99
115
|
image_projection_use_bias=True,
|
100
116
|
image_token_id=257152,
|
101
117
|
)
|
102
118
|
|
103
119
|
|
104
|
-
def
|
120
|
+
def get_fake_model_config() -> PaliGemmaConfig:
|
105
121
|
return PaliGemmaConfig(
|
106
122
|
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
|
107
123
|
decoder_config=decoder.get_fake_decoder_config(),
|
@@ -110,8 +126,8 @@ def get_fake_image_encoder_config() -> PaliGemmaConfig:
|
|
110
126
|
)
|
111
127
|
|
112
128
|
|
113
|
-
def build_model(checkpoint_path: str) -> PaliGemma:
|
114
|
-
config = get_model_config()
|
129
|
+
def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
|
130
|
+
config = get_model_config(**kwargs)
|
115
131
|
model = PaliGemma(config)
|
116
132
|
# Load the parameters of image encoder.
|
117
133
|
loader = loading_utils.ModelLoader(
|
@@ -117,7 +117,7 @@ class TestModelConversion(googletest.TestCase):
|
|
117
117
|
def _test_multisig_model(self, config, pytorch_model, atol, rtol):
|
118
118
|
# prefill
|
119
119
|
seq_len = 10
|
120
|
-
prefill_tokens = torch.
|
120
|
+
prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
|
121
121
|
prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
|
122
122
|
prefill_tokens[0, : len(prompt_token)] = prompt_token
|
123
123
|
prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.examples.gemma import gemma1
|
|
22
22
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
23
23
|
from ai_edge_torch.generative.examples.llama import llama
|
24
24
|
from ai_edge_torch.generative.examples.openelm import openelm
|
25
|
+
from ai_edge_torch.generative.examples.paligemma import paligemma
|
25
26
|
from ai_edge_torch.generative.examples.phi import phi2
|
26
27
|
from ai_edge_torch.generative.examples.phi import phi3
|
27
28
|
from ai_edge_torch.generative.examples.qwen import qwen
|
@@ -55,7 +56,7 @@ class TestModelConversion(googletest.TestCase):
|
|
55
56
|
|
56
57
|
def _test_model(self, config, model, signature_name, atol, rtol):
|
57
58
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
58
|
-
tokens = torch.
|
59
|
+
tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
|
59
60
|
tokens[0, :4] = idx
|
60
61
|
input_pos = torch.arange(0, 10, dtype=torch.int)
|
61
62
|
kv = kv_cache.KVCache.from_model_config(config)
|
@@ -171,6 +172,54 @@ class TestModelConversion(googletest.TestCase):
|
|
171
172
|
pytorch_model = model_builder.DecoderOnlyModel(config).eval()
|
172
173
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
173
174
|
|
175
|
+
@googletest.skipIf(
|
176
|
+
ai_edge_config.Config.use_torch_xla,
|
177
|
+
reason="tests with custom ops are not supported on oss",
|
178
|
+
)
|
179
|
+
def test_paligemma(self):
|
180
|
+
config = paligemma.get_fake_model_config()
|
181
|
+
pytorch_model = paligemma.PaliGemma(config).eval()
|
182
|
+
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
183
|
+
image_embedding_config = config.image_encoder_config.image_embedding
|
184
|
+
num_patches = (
|
185
|
+
image_embedding_config.image_size // image_embedding_config.patch_size
|
186
|
+
) ** 2
|
187
|
+
# Make sure the token size is longer than the number of image patches.
|
188
|
+
tokens_len = num_patches + 10
|
189
|
+
tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
|
190
|
+
tokens[0, :4] = idx
|
191
|
+
input_pos = torch.arange(0, tokens_len, dtype=torch.int)
|
192
|
+
kv = kv_cache.KVCache.from_model_config(config.decoder_config)
|
193
|
+
pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
|
194
|
+
|
195
|
+
edge_model = ai_edge_torch.signature(
|
196
|
+
"prefill_pixel",
|
197
|
+
pytorch_model,
|
198
|
+
sample_kwargs={
|
199
|
+
"tokens": tokens,
|
200
|
+
"input_pos": input_pos,
|
201
|
+
"kv_cache": kv,
|
202
|
+
"pixel_values": pixel_values,
|
203
|
+
},
|
204
|
+
).convert()
|
205
|
+
edge_model.set_interpreter_builder(
|
206
|
+
self._interpreter_builder(edge_model.tflite_model())
|
207
|
+
)
|
208
|
+
|
209
|
+
self.assertTrue(
|
210
|
+
test_utils.compare_tflite_torch(
|
211
|
+
edge_model,
|
212
|
+
pytorch_model,
|
213
|
+
tokens,
|
214
|
+
input_pos,
|
215
|
+
kv,
|
216
|
+
pixel_values=pixel_values,
|
217
|
+
signature_name="prefill_pixel",
|
218
|
+
atol=1e-3,
|
219
|
+
rtol=1e-5,
|
220
|
+
)
|
221
|
+
)
|
222
|
+
|
174
223
|
@googletest.skipIf(
|
175
224
|
ai_edge_config.Config.use_torch_xla,
|
176
225
|
reason="tests with custom ops are not supported on oss",
|
@@ -32,18 +32,21 @@ def compare_tflite_torch(
|
|
32
32
|
signature_name: str,
|
33
33
|
atol: float = 1e-5,
|
34
34
|
rtol: float = 1e-5,
|
35
|
+
**kwargs,
|
35
36
|
):
|
36
37
|
"""Compares torch models and TFLite models."""
|
37
38
|
values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
|
38
39
|
flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
|
39
|
-
torch_output = torch_model(tokens, input_pos, kv_cache)
|
40
|
+
torch_output = torch_model(tokens, input_pos, kv_cache, **kwargs)
|
40
41
|
|
41
|
-
|
42
|
+
if "pixel_values" in kwargs:
|
43
|
+
kwargs["pixel_values"] = kwargs["pixel_values"].numpy()
|
44
|
+
kwargs.update({k: v.numpy() for k, v in zip(flat_names, values)})
|
42
45
|
edge_output = edge_model(
|
43
46
|
signature_name=signature_name,
|
44
47
|
tokens=tokens.numpy(),
|
45
48
|
input_pos=input_pos.numpy(),
|
46
|
-
**
|
49
|
+
**kwargs,
|
47
50
|
)
|
48
51
|
|
49
52
|
return np.allclose(
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
"""Common utility functions for model conversion."""
|
17
17
|
|
18
|
-
import
|
18
|
+
from typing import Union
|
19
|
+
|
19
20
|
from ai_edge_torch._convert import converter as converter_utils
|
20
|
-
|
21
|
+
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
22
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
21
23
|
from ai_edge_torch.generative.quantize import quant_recipes
|
22
24
|
import torch
|
23
25
|
|
@@ -25,109 +27,74 @@ import torch
|
|
25
27
|
def convert_to_tflite(
|
26
28
|
pytorch_model: torch.nn.Module,
|
27
29
|
tflite_path: str,
|
28
|
-
prefill_seq_len: int
|
30
|
+
prefill_seq_len: Union[int, list[int]],
|
31
|
+
pixel_values_size: torch.Size = None,
|
29
32
|
quantize: bool = True,
|
33
|
+
config: cfg.ModelConfig = None,
|
30
34
|
):
|
31
35
|
"""Converts a nn.Module model to multi-signature tflite model.
|
32
36
|
|
33
|
-
A PyTorch model will be converted to a tflite model with
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
"
|
41
|
-
|
42
|
-
|
37
|
+
A PyTorch model will be converted to a tflite model with several signatures:
|
38
|
+
* "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
|
39
|
+
passed),
|
40
|
+
* "prefill_[preill_seq_len]_pixel" (or "prefill_pixel" if only one
|
41
|
+
prefill_seq_len is passed) if num_pixel_values > 0, and
|
42
|
+
* "decode".
|
43
|
+
|
44
|
+
"prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
|
45
|
+
passed) signature takes as a sample input:
|
46
|
+
* a tensor of shape [1, prefill_seq_len] of token sequence,
|
47
|
+
* a tensor of shape [1, prefill_seq_len] of token positions, and
|
48
|
+
* an external KV cache.
|
49
|
+
|
50
|
+
If num_pixel_values > 0, "prefill_[prefill_seq_len]_pixel" (or "prefill_pixel"
|
51
|
+
if only one prefill_seq_len is passed) signature takes as a sample input:
|
52
|
+
* a tensor of shape [1, prefill_seq_len] of token sequence,
|
53
|
+
* a tensor of shape [1, prefill_seq_len] of token positions,
|
54
|
+
* an external KV cache, and
|
55
|
+
* a tensor of shape [1, num_pixel_values] of pixel values.
|
56
|
+
|
57
|
+
"decode" signature takes as a sample input:
|
58
|
+
* a tensor of shape [1, 1] of token sequence,
|
59
|
+
* a tensor of shape [1, 1] of the token position, and
|
60
|
+
* an external KV cache.
|
43
61
|
|
44
62
|
The final tflite model will be exported to tflite_path.
|
45
63
|
|
46
64
|
Args:
|
47
65
|
pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
|
48
66
|
tflite_path (str): The tflite file path to export.
|
49
|
-
prefill_seq_len (int,
|
50
|
-
|
67
|
+
prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to
|
68
|
+
export.
|
69
|
+
pixel_values_size (torch.Size, optional): The size of pixel values to pass
|
70
|
+
to the model. If None, the model is not expected to take pixel values.
|
51
71
|
quantize (bool, optional): Whether the model should be quanized. Defaults
|
52
72
|
to True.
|
73
|
+
config (cfg.ModelConfig, optional): The model config used to configure KV
|
74
|
+
cache. If None, it uses the config of the pytorch_model.
|
53
75
|
"""
|
54
|
-
|
55
|
-
|
56
|
-
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
57
|
-
decode_token = torch.tensor([[0]], dtype=torch.int)
|
58
|
-
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
59
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
60
|
-
|
61
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
62
|
-
edge_model = (
|
63
|
-
ai_edge_torch.signature(
|
64
|
-
'prefill',
|
65
|
-
pytorch_model,
|
66
|
-
sample_kwargs={
|
67
|
-
'tokens': prefill_tokens,
|
68
|
-
'input_pos': prefill_input_pos,
|
69
|
-
'kv_cache': kv,
|
70
|
-
},
|
71
|
-
)
|
72
|
-
.signature(
|
73
|
-
'decode',
|
74
|
-
pytorch_model,
|
75
|
-
sample_kwargs={
|
76
|
-
'tokens': decode_token,
|
77
|
-
'input_pos': decode_input_pos,
|
78
|
-
'kv_cache': kv,
|
79
|
-
},
|
80
|
-
)
|
81
|
-
.convert(quant_config=quant_config)
|
76
|
+
prefill_seq_lens = (
|
77
|
+
[prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
|
82
78
|
)
|
83
|
-
edge_model.export(tflite_path)
|
84
|
-
|
85
|
-
|
86
|
-
def convert_to_tflite_multi_prefill_lens(
|
87
|
-
pytorch_model: torch.nn.Module,
|
88
|
-
tflite_path: str,
|
89
|
-
prefill_seq_lens: list[int],
|
90
|
-
quantize: bool = True,
|
91
|
-
):
|
92
|
-
"""Converts a nn.Module model to multi-signature tflite model with different
|
93
|
-
|
94
|
-
prefill lengths.
|
95
|
-
|
96
|
-
A PyTorch model will be converted to a tflite model with several signatures:
|
97
|
-
"prefill_[prefill_seq_len]" and "decode".
|
98
|
-
|
99
|
-
"prefill_[prefill_seq_len]" signature takes a tensor of shape [1,
|
100
|
-
prefill_seq_len] of token
|
101
|
-
sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
|
102
|
-
external KV cache as a sample input.
|
103
|
-
|
104
|
-
"decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
|
105
|
-
of shape [1, 1] of the token position, and an external KV cache as a sample
|
106
|
-
input.
|
107
|
-
|
108
|
-
The final tflite model will be exported to tflite_path.
|
109
79
|
|
110
|
-
Args:
|
111
|
-
pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
|
112
|
-
tflite_path (str): The tflite file path to export.
|
113
|
-
prefill_seq_lens (list[int]): A list of prefill lengths to export.
|
114
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
115
|
-
to True.
|
116
|
-
"""
|
117
80
|
# Tensors used to trace the model graph during conversion.
|
118
81
|
prefill_tokens_list = []
|
119
82
|
prefill_input_pos_list = []
|
120
|
-
for
|
121
|
-
prefill_tokens_list.append(
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
83
|
+
for seq_len in prefill_seq_lens:
|
84
|
+
prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
|
85
|
+
prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
|
86
|
+
|
87
|
+
prefill_pixel_values = (
|
88
|
+
torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
|
89
|
+
if pixel_values_size
|
90
|
+
else None
|
91
|
+
)
|
127
92
|
|
128
93
|
decode_token = torch.tensor([[0]], dtype=torch.int)
|
129
94
|
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
130
|
-
kv = kv_utils.KVCache.from_model_config(
|
95
|
+
kv = kv_utils.KVCache.from_model_config(
|
96
|
+
config if config else pytorch_model.config
|
97
|
+
)
|
131
98
|
|
132
99
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
133
100
|
converter = converter_utils.Converter()
|
@@ -135,8 +102,12 @@ def convert_to_tflite_multi_prefill_lens(
|
|
135
102
|
prefill_seq_len = prefill_seq_lens[i]
|
136
103
|
prefill_tokens = prefill_tokens_list[i]
|
137
104
|
prefill_input_pos = prefill_input_pos_list[i]
|
105
|
+
if i == 0 and len(prefill_seq_lens) == 1:
|
106
|
+
prefill_signature_name = 'prefill'
|
107
|
+
else:
|
108
|
+
prefill_signature_name = f'prefill_{prefill_seq_len}'
|
138
109
|
converter.add_signature(
|
139
|
-
|
110
|
+
prefill_signature_name,
|
140
111
|
pytorch_model,
|
141
112
|
sample_kwargs={
|
142
113
|
'tokens': prefill_tokens,
|
@@ -144,8 +115,19 @@ def convert_to_tflite_multi_prefill_lens(
|
|
144
115
|
'kv_cache': kv,
|
145
116
|
},
|
146
117
|
)
|
118
|
+
if prefill_pixel_values is not None:
|
119
|
+
converter.add_signature(
|
120
|
+
prefill_signature_name + '_pixel',
|
121
|
+
pytorch_model,
|
122
|
+
sample_kwargs={
|
123
|
+
'tokens': prefill_tokens,
|
124
|
+
'input_pos': prefill_input_pos,
|
125
|
+
'kv_cache': kv,
|
126
|
+
'pixel_values': prefill_pixel_values,
|
127
|
+
},
|
128
|
+
)
|
147
129
|
|
148
|
-
|
130
|
+
converter.add_signature(
|
149
131
|
'decode',
|
150
132
|
pytorch_model,
|
151
133
|
sample_kwargs={
|
@@ -153,5 +135,7 @@ def convert_to_tflite_multi_prefill_lens(
|
|
153
135
|
'input_pos': decode_input_pos,
|
154
136
|
'kv_cache': kv,
|
155
137
|
},
|
156
|
-
)
|
138
|
+
)
|
139
|
+
|
140
|
+
edge_model = converter.convert(quant_config=quant_config)
|
157
141
|
edge_model.export(tflite_path)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241119
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=W10ztMY91LZfiK-COm46eLLfufHNyUl2W0DtuM8zeC4,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -45,8 +45,7 @@ ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n7
|
|
45
45
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
|
46
46
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
47
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
|
48
|
-
ai_edge_torch/generative/examples/gemma/
|
49
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
|
48
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=Mss7j3yDhyOFCWA93iWh995CLeNBDTVG-gvpj6WBIp0,2226
|
50
49
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
|
51
50
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
|
52
51
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
@@ -61,9 +60,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
|
|
61
60
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
|
62
61
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
63
62
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
63
|
+
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=dT7dnx1dzGzFiH5gQJ4M6zcTLSRFvSDpi3IuZ9_vd78,2706
|
64
64
|
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
|
65
|
-
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=
|
66
|
-
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
|
65
|
+
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
|
66
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
|
67
67
|
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
|
68
68
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
69
69
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
|
@@ -135,12 +135,12 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
|
|
135
135
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
136
136
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
137
137
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
138
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
139
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
138
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
|
139
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
|
140
140
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
141
|
-
ai_edge_torch/generative/test/utils.py,sha256=
|
141
|
+
ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
|
142
142
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
143
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
143
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
|
144
144
|
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
145
145
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
|
146
146
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
@@ -193,8 +193,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
193
193
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
194
194
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
195
195
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
196
|
-
ai_edge_torch_nightly-0.3.0.
|
197
|
-
ai_edge_torch_nightly-0.3.0.
|
198
|
-
ai_edge_torch_nightly-0.3.0.
|
199
|
-
ai_edge_torch_nightly-0.3.0.
|
200
|
-
ai_edge_torch_nightly-0.3.0.
|
196
|
+
ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
197
|
+
ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/METADATA,sha256=ikosXUollu7saRd3GUi1IZw78jvgCShWAhtDF3NuUtE,1897
|
198
|
+
ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
199
|
+
ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
200
|
+
ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/RECORD,,
|
File without changes
|
File without changes
|