ai-edge-torch-nightly 0.3.0.dev20241117__py3-none-any.whl → 0.3.0.dev20241119__py3-none-any.whl

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+_PREFILL_SEQ_LEN = flags.DEFINE_multi_integer(
     'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
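
This hunk (the Gemma2 example script, judging by the RECORD changes at the bottom) switches the prefill flag from a single integer to a multi-integer. With absl's `DEFINE_multi_integer`, repeating the flag on the command line appends values and the tuple above is the default. A standalone sketch of that behavior (the demo `main` is illustrative, not part of the package):

from absl import app
from absl import flags

_PREFILL_SEQ_LEN = flags.DEFINE_multi_integer(
    'prefill_seq_len',
    (8, 64, 128, 256, 512, 1024),
    'List of the maximum sizes of prefill input tensors.',
)

def main(_):
  # With no flag passed, .value is the default list above; passing
  # --prefill_seq_len=64 --prefill_seq_len=512 yields [64, 512].
  print(_PREFILL_SEQ_LEN.value)

if __name__ == '__main__':
  app.run(main)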
@@ -13,20 +13,24 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of converting a Gemma2 model to multi-signature tflite model, with multiple prefill lengths."""
+"""Example of converting a PaliGemma model to multi-signature tflite model.
+
+DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
+https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
+"""

-import logging
 import os
 import pathlib

 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
+import torch

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -34,16 +38,21 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    (8, 64, 128, 256, 512, 1024),
-    'A list of prefill lengths to export.',
+    1024,
+    'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
     1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
+_PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
+    'pixel_values_size',
+    [3, 224, 224],
+    'The size of prefill pixel values except the batch dimension.',
+)
 _QUANTIZE = flags.DEFINE_bool(
     'quantize',
     True,
@@ -51,32 +60,21 @@ _QUANTIZE = flags.DEFINE_bool(
 )


-# Note that the converted model is not compatible with LLM Inference engine for
-# now. The main purpose for this function is to allow you export a tflite model
-# with multiple prefill signatures for different prefill lengths for faster
-# inference.
-def convert_to_tflite_multi_prefill_lens():
-  pytorch_model = gemma2.build_2b_model(
+def main(_):
+  pytorch_model = paligemma.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_multi-prefill-seq_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  converter.convert_to_tflite_multi_prefill_lens(
+  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_lens=_PREFILL_SEQ_LENS.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
      quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
   )


-def main(_):
-  if len(_PREFILL_SEQ_LENS.value) > 1:
-    # If multiple prefill lengths are provided, export a model with multiple
-    # prefill signatures each for a different prefill length.
-    convert_to_tflite_multi_prefill_lens()
-  else:
-    logging.warning('Need more than one prefill lengths to be specified.')
-
-
 if __name__ == '__main__':
   app.run(main)
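
The new PaliGemma example script builds the multimodal model and hands its decoder config to the converter. A condensed sketch of the same export done programmatically (paths are placeholders; every API name comes from the diff above):

import torch

from ai_edge_torch.generative.examples.paligemma import paligemma
from ai_edge_torch.generative.utilities import converter

# Build the PyTorch model; kv_cache_max_len is forwarded to the decoder config.
pytorch_model = paligemma.build_model(
    '/path/to/paligemma-3b-224', kv_cache_max_len=1280
)
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/paligemma_q8_seq1024_ekv1280.tflite',
    prefill_seq_len=1024,
    # One 3x224x224 image per prefill call.
    pixel_values_size=torch.Size([3, 224, 224]),
    quantize=True,
    # The KV cache is configured from the decoder config, not from the
    # multimodal wrapper config.
    config=pytorch_model.config.decoder_config,
)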
@@ -59,7 +59,7 @@ class SiglipVisionEncoder(nn.Module):
         out_channels=config.embedding_dim,
         kernel_size=config.image_embedding.patch_size,
         stride=config.image_embedding.patch_size,
-        padding="valid",
+        padding=0,
     )
     num_patches = (
         config.image_embedding.image_size // config.image_embedding.patch_size
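
In `torch.nn.Conv2d`, the string `padding="valid"` and the integer `padding=0` mean the same thing (no padding); the integer spelling is simply friendlier to the conversion path. A standalone equivalence check (the sizes here are illustrative, not PaliGemma's real config):

import torch
from torch import nn

x = torch.randn(1, 3, 224, 224)
conv_str = nn.Conv2d(3, 8, kernel_size=16, stride=16, padding="valid")
conv_int = nn.Conv2d(3, 8, kernel_size=16, stride=16, padding=0)
conv_int.load_state_dict(conv_str.state_dict())  # copy the same random weights

# No padding either way, so both yield the same 14x14 grid of patch outputs.
assert torch.equal(conv_str(x), conv_int(x))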
@@ -144,6 +144,8 @@ def get_fake_image_encoder_config() -> cfg.ModelConfig:
   config = get_image_encoder_config()
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
+  config.image_embedding.image_size = 8
+  config.image_embedding.patch_size = 2
   config.num_layers = 2
   return config

@@ -54,6 +54,10 @@ class PaliGemma(nn.Module):
         bias=config.image_projection_use_bias,
     )
     self.decoder = decoder.Decoder(config.decoder_config)
+    image_embedding_config = config.image_encoder_config.image_embedding
+    self.num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
     self.config = config

   @torch.inference_mode
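
`num_patches` is just the squared ratio of image size to patch size; the decoder treats that many leading positions as image tokens. A quick check with the fake test config introduced above (values taken from this diff):

# Fake test config: 8x8 images split into 2x2 patches.
image_size, patch_size = 8, 2
num_patches = (image_size // patch_size) ** 2
assert num_patches == 16  # matches the (1, 3, 8, 8) pixel_values in the new test

# The real 3B-224 config uses the same formula with the image_size and
# patch_size defined in image_encoder.get_image_encoder_config() (not shown
# in this diff).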
@@ -74,10 +78,22 @@ class PaliGemma(nn.Module):
     if self.config.decoder_config.embedding_scale is not None:
       image_embeds = image_embeds / self.config.decoder_config.embedding_scale

-    # Merge image_embeds into text_embeds as PaliGemmaForConditionalGeneration.
-    image_mask = tokens == self.config.image_token_id
-    image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
-    input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+    # can be done like:
+    #
+    #   image_mask = tokens == self.config.image_token_id
+    #   image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    #   input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Since PaliGemma token embedder reserves the first [num_patches] tokens
+    # for image tokens, we can use this property to merge image_embeds into
+    # input_embeds by concatenating them.
+    assert image_embeds.shape[1] == self.num_patches
+    assert input_embeds.shape[1] >= self.num_patches
+    input_embeds = torch.cat(
+        (image_embeds, input_embeds[:, self.num_patches:, :]), dim=1
+    )

     return self.decoder(
         tokens=None,
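
The concatenation trick only matches `masked_scatter` because PaliGemma places all image tokens at the front of the sequence. A self-contained check of the equivalence the comment describes, with toy sizes and a made-up token id (`masked_scatter` runs fine in eager mode; only its TFLite lowering fails):

import torch

num_patches, seq_len, dim, image_token_id = 4, 7, 3, 99
tokens = torch.tensor([[image_token_id] * num_patches + [5, 6, 7]])
input_embeds = torch.randn(1, seq_len, dim)
image_embeds = torch.randn(1, num_patches, dim)

# Reference path: scatter image embeddings over the image-token positions.
mask = (tokens == image_token_id).unsqueeze(-1).expand_as(input_embeds)
scattered = input_embeds.masked_scatter(mask, image_embeds)

# Export-friendly path: the image tokens occupy the first num_patches
# positions, so concatenating the two slices produces the identical tensor.
concatenated = torch.cat((image_embeds, input_embeds[:, num_patches:, :]), dim=1)
assert torch.equal(scattered, concatenated)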
@@ -87,7 +103,7 @@
     )


-def get_model_config() -> PaliGemmaConfig:
+def get_model_config(**kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.

   Returns:
@@ -95,13 +111,13 @@ def get_model_config() -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
       image_projection_use_bias=True,
       image_token_id=257152,
   )


-def get_fake_image_encoder_config() -> PaliGemmaConfig:
+def get_fake_model_config() -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=decoder.get_fake_decoder_config(),
@@ -110,8 +126,8 @@ def get_fake_image_encoder_config() -> PaliGemmaConfig:
   )


-def build_model(checkpoint_path: str) -> PaliGemma:
-  config = get_model_config()
+def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
+  config = get_model_config(**kwargs)
   model = PaliGemma(config)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
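
The `**kwargs` threading exists so a caller can reach the decoder config without building one by hand; in particular, the new conversion script passes `kv_cache_max_len` this way. A sketch (the checkpoint path is a placeholder):

from ai_edge_torch.generative.examples.paligemma import paligemma

# kwargs flow build_model -> get_model_config -> decoder.get_decoder_config,
# so the decoder's KV cache size is set without touching the image encoder.
model = paligemma.build_model(
    '/path/to/paligemma-3b-224', kv_cache_max_len=1280
)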
@@ -117,7 +117,7 @@ class TestModelConversion(googletest.TestCase):
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
+    prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
     prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.qwen import qwen
@@ -55,7 +56,7 @@ class TestModelConversion(googletest.TestCase):

   def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
+    tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
@@ -171,6 +172,54 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_paligemma(self):
+    config = paligemma.get_fake_model_config()
+    pytorch_model = paligemma.PaliGemma(config).eval()
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    image_embedding_config = config.image_encoder_config.image_embedding
+    num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
+    # Make sure the token size is longer than the number of image patches.
+    tokens_len = num_patches + 10
+    tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
+
+    edge_model = ai_edge_torch.signature(
+        "prefill_pixel",
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+            "pixel_values": pixel_values,
+        },
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            tokens,
+            input_pos,
+            kv,
+            pixel_values=pixel_values,
+            signature_name="prefill_pixel",
+            atol=1e-3,
+            rtol=1e-5,
+        )
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -32,18 +32,21 @@ def compare_tflite_torch(
     signature_name: str,
     atol: float = 1e-5,
     rtol: float = 1e-5,
+    **kwargs,
 ):
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
-  torch_output = torch_model(tokens, input_pos, kv_cache)
+  torch_output = torch_model(tokens, input_pos, kv_cache, **kwargs)

-  input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
+  if "pixel_values" in kwargs:
+    kwargs["pixel_values"] = kwargs["pixel_values"].numpy()
+  kwargs.update({k: v.numpy() for k, v in zip(flat_names, values)})
   edge_output = edge_model(
       signature_name=signature_name,
       tokens=tokens.numpy(),
       input_pos=input_pos.numpy(),
-      **input_kv_flatten,
+      **kwargs,
   )

   return np.allclose(
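
Order matters in the new body: `pixel_values` is converted to numpy first, then `kwargs.update(...)` folds in the flattened KV-cache entries, leaving one all-numpy dict for the interpreter call. A toy run of just that merge (the flat names here are made up; the real ones come from `common_utils.flat_dict_names`):

import numpy as np
import torch

kwargs = {"pixel_values": torch.zeros(1, 3, 8, 8)}
flat_names = ["kv_k_0", "kv_v_0"]  # hypothetical flattened cache names
values = [torch.zeros(1, 4), torch.zeros(1, 4)]

if "pixel_values" in kwargs:
  kwargs["pixel_values"] = kwargs["pixel_values"].numpy()
kwargs.update({k: v.numpy() for k, v in zip(flat_names, values)})

assert all(isinstance(v, np.ndarray) for v in kwargs.values())
print(sorted(kwargs))  # ['kv_k_0', 'kv_v_0', 'pixel_values']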
@@ -15,9 +15,11 @@

 """Common utility functions for model conversion."""

-import ai_edge_torch
+from typing import Union
+
 from ai_edge_torch._convert import converter as converter_utils
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -25,109 +27,74 @@ import torch
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
     tflite_path: str,
-    prefill_seq_len: int = 512,
+    prefill_seq_len: Union[int, list[int]],
+    pixel_values_size: torch.Size = None,
     quantize: bool = True,
+    config: cfg.ModelConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.

-  A PyTorch model will be converted to a tflite model with two signatures:
-  "prefill" and "decode".
-
-  "prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
-  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
-  external KV cache as a sample input.
-
-  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
-  of shape [1, 1] of the token position, and an external KV cache as a sample
-  input.
+  A PyTorch model will be converted to a tflite model with several signatures:
+    * "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
+      passed),
+    * "prefill_[prefill_seq_len]_pixel" (or "prefill_pixel" if only one
+      prefill_seq_len is passed) if pixel_values_size is given, and
+    * "decode".
+
+  The "prefill_[prefill_seq_len]" (or "prefill") signature takes as a sample
+  input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions, and
+    * an external KV cache.
+
+  If pixel_values_size is given, the "prefill_[prefill_seq_len]_pixel" (or
+  "prefill_pixel") signature takes as a sample input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions,
+    * an external KV cache, and
+    * a tensor of shape [1, *pixel_values_size] of pixel values.
+
+  The "decode" signature takes as a sample input:
+    * a tensor of shape [1, 1] of token sequence,
+    * a tensor of shape [1, 1] of the token position, and
+    * an external KV cache.

   The final tflite model will be exported to tflite_path.

   Args:
     pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
     tflite_path (str): The tflite file path to export.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
+    prefill_seq_len (Union[int, list[int]]): The prefill length, or the list of
+      prefill lengths, to export.
+    pixel_values_size (torch.Size, optional): The size of pixel values to pass
+      to the model. If None, the model is not expected to take pixel values.
     quantize (bool, optional): Whether the model should be quantized. Defaults
       to True.
+    config (cfg.ModelConfig, optional): The model config used to configure the
+      KV cache. If None, the config of the pytorch_model is used.
   """
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
+  prefill_seq_lens = (
+      [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
   )
-  edge_model.export(tflite_path)
-
-
-def convert_to_tflite_multi_prefill_lens(
-    pytorch_model: torch.nn.Module,
-    tflite_path: str,
-    prefill_seq_lens: list[int],
-    quantize: bool = True,
-):
-  """Converts a nn.Module model to multi-signature tflite model with different
-
-  prefill lengths.
-
-  A PyTorch model will be converted to a tflite model with several signatures:
-  "prefill_[prefill_seq_len]" and "decode".
-
-  "prefill_[prefill_seq_len]" signature takes a tensor of shape [1,
-  prefill_seq_len] of token
-  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
-  external KV cache as a sample input.
-
-  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
-  of shape [1, 1] of the token position, and an external KV cache as a sample
-  input.
-
-  The final tflite model will be exported to tflite_path.

-  Args:
-    pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-    tflite_path (str): The tflite file path to export.
-    prefill_seq_lens (list[int]): A list of prefill lengths to export.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
   # Tensors used to trace the model graph during conversion.
   prefill_tokens_list = []
   prefill_input_pos_list = []
-  for prefill_seq_len in prefill_seq_lens:
-    prefill_tokens_list.append(
-        torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-    )
-    prefill_input_pos_list.append(
-        torch.arange(0, prefill_seq_len, dtype=torch.int)
-    )
+  for seq_len in prefill_seq_lens:
+    prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
+    prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
+
+  prefill_pixel_values = (
+      torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
+      if pixel_values_size
+      else None
+  )

   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+  kv = kv_utils.KVCache.from_model_config(
+      config if config else pytorch_model.config
+  )

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   converter = converter_utils.Converter()
@@ -135,8 +102,12 @@ def convert_to_tflite_multi_prefill_lens(
     prefill_seq_len = prefill_seq_lens[i]
     prefill_tokens = prefill_tokens_list[i]
     prefill_input_pos = prefill_input_pos_list[i]
+    if i == 0 and len(prefill_seq_lens) == 1:
+      prefill_signature_name = 'prefill'
+    else:
+      prefill_signature_name = f'prefill_{prefill_seq_len}'
     converter.add_signature(
-        f'prefill_{prefill_seq_len}',
+        prefill_signature_name,
         pytorch_model,
         sample_kwargs={
             'tokens': prefill_tokens,
@@ -144,8 +115,19 @@ def convert_to_tflite_multi_prefill_lens(
             'kv_cache': kv,
         },
     )
+    if prefill_pixel_values is not None:
+      converter.add_signature(
+          prefill_signature_name + '_pixel',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+              'pixel_values': prefill_pixel_values,
+          },
+      )

-  edge_model = converter.add_signature(
+  converter.add_signature(
       'decode',
       pytorch_model,
       sample_kwargs={
@@ -153,5 +135,7 @@ def convert_to_tflite_multi_prefill_lens(
           'input_pos': decode_input_pos,
           'kv_cache': kv,
       },
-  ).convert(quant_config=quant_config)
+  )
+
+  edge_model = converter.convert(quant_config=quant_config)
   edge_model.export(tflite_path)
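
With the two functions folded into one, a single entry point now covers both shapes of export. A sketch of both call styles (assumes `pytorch_model` was built as in the PaliGemma example above; for a text-only model, omit `pixel_values_size` and `config`):

from ai_edge_torch.generative.utilities import converter

# One length: exports "prefill" (plus "prefill_pixel" when pixel_values_size
# is given) and "decode".
converter.convert_to_tflite(
    pytorch_model, tflite_path='/tmp/model.tflite', prefill_seq_len=1024
)

# Several lengths: exports "prefill_8", "prefill_64", ..., "prefill_1024"
# plus "decode", all sharing one KV cache layout.
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/model_multi.tflite',
    prefill_seq_len=[8, 64, 128, 256, 512, 1024],
)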
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.3.0.dev20241117"
+__version__ = "0.3.0.dev20241119"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241117
+Version: 0.3.0.dev20241119
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=7_Q9eL2px6cC84NORS1VimuJ-a5-0vrWBZ38lYckMtA,706
+ai_edge_torch/version.py,sha256=W10ztMY91LZfiK-COm46eLLfufHNyUl2W0DtuM8zeC4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -45,8 +45,7 @@ ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n7
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
-ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=bZKOiAJBWPzIVHdASEgKRUFdyZSPVGFfe3uXUYrRh1c,2868
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=Mss7j3yDhyOFCWA93iWh995CLeNBDTVG-gvpj6WBIp0,2226
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
@@ -61,9 +60,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=dT7dnx1dzGzFiH5gQJ4M6zcTLSRFvSDpi3IuZ9_vd78,2706
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=mbq9CBp2znXPIQdzIQTiQGRh4Ql3bn9kyX-k_LXKTms,4537
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
@@ -135,12 +135,12 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TzBEbWOoB7bIHePuP6ySL9eYfmKHpONgTQCU-f05m8c,9497
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
-ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
+ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
+ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -193,8 +193,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241117.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241117.dist-info/METADATA,sha256=66qws6AuCk-8cDSyPhBpsjrASeUYB3G2PMwJG5BeflI,1897
-ai_edge_torch_nightly-0.3.0.dev20241117.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
-ai_edge_torch_nightly-0.3.0.dev20241117.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241117.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/METADATA,sha256=ikosXUollu7saRd3GUi1IZw78jvgCShWAhtDF3NuUtE,1897
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/RECORD,,