ai-edge-torch-nightly 0.3.0.dev20241116__py3-none-any.whl → 0.3.0.dev20241119__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+_PREFILL_SEQ_LEN = flags.DEFINE_multi_integer(
     'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
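With flags.DEFINE_multi_integer, repeating the flag on the command line accumulates values into a list, so one invocation can request several prefill signatures. A minimal stand-alone sketch of that flag behavior (script and values here are illustrative, not part of the package):

from absl import app
from absl import flags

# Illustrative only: mirrors the flag declaration in the hunk above.
_PREFILL_SEQ_LEN = flags.DEFINE_multi_integer(
    'prefill_seq_len', (8, 64), 'Maximum sizes of prefill input tensors.'
)


def main(_):
  # --prefill_seq_len=64 --prefill_seq_len=256 yields [64, 256];
  # with no flag given, the default is returned as a list.
  print(_PREFILL_SEQ_LEN.value)


if __name__ == '__main__':
  app.run(main)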
ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py → ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py RENAMED
@@ -13,20 +13,24 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Example of converting a Gemma2 model to multi-signature tflite model, with multiple prefill lengths."""
+"""Example of converting a PaliGemma model to multi-signature tflite model.
+
+DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
+https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
+"""
 
-import logging
 import os
 import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
+import torch
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -34,16 +38,21 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    (8, 64, 128, 256, 512, 1024),
-    'A list of prefill lengths to export.',
+    1024,
+    'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
     1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
+_PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
+    'pixel_values_size',
+    [3, 224, 224],
+    'The size of prefill pixel values except the batch dimension.',
+)
 _QUANTIZE = flags.DEFINE_bool(
     'quantize',
     True,
@@ -51,32 +60,21 @@ _QUANTIZE = flags.DEFINE_bool(
 )
 
 
-# Note that the converted model is not compatible with LLM Inference engine for
-# now. The main purpose for this function is to allow you export a tflite model
-# with multiple prefill signatures for different prefill lengths for faster
-# inference.
-def convert_to_tflite_multi_prefill_lens():
-  pytorch_model = gemma2.build_2b_model(
+def main(_):
+  pytorch_model = paligemma.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_multi-prefill-seq_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  converter.convert_to_tflite_multi_prefill_lens(
+  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_lens=_PREFILL_SEQ_LENS.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
      quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
   )
 
 
-def main(_):
-  if len(_PREFILL_SEQ_LENS.value) > 1:
-    # If multiple prefill lengths are provided, export a model with multiple
-    # prefill signatures each for a different prefill length.
-    convert_to_tflite_multi_prefill_lens()
-  else:
-    logging.warning('Need more than one prefill lengths to be specified.')
-
-
 if __name__ == '__main__':
   app.run(main)
ai_edge_torch/generative/examples/paligemma/image_encoder.py CHANGED
@@ -59,7 +59,7 @@ class SiglipVisionEncoder(nn.Module):
         out_channels=config.embedding_dim,
         kernel_size=config.image_embedding.patch_size,
         stride=config.image_embedding.patch_size,
-        padding="valid",
+        padding=0,
     )
     num_patches = (
         config.image_embedding.image_size // config.image_embedding.patch_size
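The padding change is behavior-preserving: for nn.Conv2d, the string "valid" means no padding, which is exactly what the integer 0 requests, and the integer form is presumably friendlier to the conversion backend. A quick sanity sketch of the equivalence (arbitrary shapes):

import torch
from torch import nn

# Two convolutions that differ only in how "no padding" is spelled.
conv_str = nn.Conv2d(3, 8, kernel_size=2, stride=2, padding="valid")
conv_int = nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=0)
conv_int.load_state_dict(conv_str.state_dict())  # copy identical weights

x = torch.randn(1, 3, 8, 8)
assert torch.equal(conv_str(x), conv_int(x))  # identical outputs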
@@ -144,6 +144,8 @@ def get_fake_image_encoder_config() -> cfg.ModelConfig:
   config = get_image_encoder_config()
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
+  config.image_embedding.image_size = 8
+  config.image_embedding.patch_size = 2
   config.num_layers = 2
   return config
 
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -54,6 +54,10 @@ class PaliGemma(nn.Module):
         bias=config.image_projection_use_bias,
     )
     self.decoder = decoder.Decoder(config.decoder_config)
+    image_embedding_config = config.image_encoder_config.image_embedding
+    self.num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
     self.config = config
 
   @torch.inference_mode
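The patch count cached here determines how many leading token positions carry image embeddings. Worked numbers (image_size=224 and patch_size=14 are assumed for the real checkpoint; the fake test config above uses 8 and 2):

# Real PaliGemma-224 encoder, under the assumed sizes:
assert (224 // 14) ** 2 == 256  # 256 leading positions hold image embeds

# Fake test config (image_size=8, patch_size=2):
assert (8 // 2) ** 2 == 16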
@@ -74,10 +78,22 @@ class PaliGemma(nn.Module):
     if self.config.decoder_config.embedding_scale is not None:
       image_embeds = image_embeds / self.config.decoder_config.embedding_scale
 
-    # Merge image_embeds into text_embeds as PaliGemmaForConditionalGeneration.
-    image_mask = tokens == self.config.image_token_id
-    image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
-    input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+    # can be done like:
+    #
+    #   image_mask = tokens == self.config.image_token_id
+    #   image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    #   input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Since PaliGemma token embedder reserves the first [num_patches] tokens
+    # for image tokens, we can use this property to merge image_embeds into
+    # input_embeds by concatenating them.
+    assert image_embeds.shape[1] == self.num_patches
+    assert input_embeds.shape[1] >= self.num_patches
+    input_embeds = torch.cat(
+        (image_embeds, input_embeds[:, self.num_patches:, :]), dim=1
+    )
 
     return self.decoder(
         tokens=None,
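The concatenation is equivalent to the masked_scatter path whenever the image tokens occupy exactly the first num_patches positions, which is how the PaliGemma prompt is laid out. A small self-contained check (toy shapes and a made-up image_token_id):

import torch

num_patches, seq_len, dim, image_token_id = 4, 7, 3, 99
tokens = torch.tensor([[image_token_id] * num_patches + [1, 2, 3]])
input_embeds = torch.zeros(1, seq_len, dim)
image_embeds = torch.ones(1, num_patches, dim)

# Reference path (the one that fails to lower on CPU per the comment above).
mask = (tokens == image_token_id).unsqueeze(-1).expand_as(input_embeds)
ref = input_embeds.masked_scatter(mask, image_embeds)

# Concatenation path used by the new code.
out = torch.cat((image_embeds, input_embeds[:, num_patches:, :]), dim=1)
assert torch.equal(ref, out)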
@@ -87,7 +103,7 @@
     )
 
 
-def get_model_config() -> PaliGemmaConfig:
+def get_model_config(**kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
@@ -95,13 +111,13 @@ def get_model_config() -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
       image_projection_use_bias=True,
       image_token_id=257152,
   )
 
 
-def get_fake_image_encoder_config() -> PaliGemmaConfig:
+def get_fake_model_config() -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=decoder.get_fake_decoder_config(),
@@ -110,8 +126,8 @@ def get_fake_image_encoder_config() -> PaliGemmaConfig:
   )
 
 
-def build_model(checkpoint_path: str) -> PaliGemma:
-  config = get_model_config()
+def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
+  config = get_model_config(**kwargs)
   model = PaliGemma(config)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -117,7 +117,7 @@ class TestModelConversion(googletest.TestCase):
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
+    prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
     prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.qwen import qwen
@@ -55,7 +56,7 @@ class TestModelConversion(googletest.TestCase):
 
   def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
+    tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
@@ -171,6 +172,54 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_paligemma(self):
+    config = paligemma.get_fake_model_config()
+    pytorch_model = paligemma.PaliGemma(config).eval()
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    image_embedding_config = config.image_encoder_config.image_embedding
+    num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
+    # Make sure the token size is longer than the number of image patches.
+    tokens_len = num_patches + 10
+    tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
+
+    edge_model = ai_edge_torch.signature(
+        "prefill_pixel",
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+            "pixel_values": pixel_values,
+        },
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            tokens,
+            input_pos,
+            kv,
+            pixel_values=pixel_values,
+            signature_name="prefill_pixel",
+            atol=1e-3,
+            rtol=1e-5,
+        )
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/test/utils.py CHANGED
@@ -32,18 +32,21 @@ def compare_tflite_torch(
     signature_name: str,
     atol: float = 1e-5,
     rtol: float = 1e-5,
+    **kwargs,
 ):
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
-  torch_output = torch_model(tokens, input_pos, kv_cache)
+  torch_output = torch_model(tokens, input_pos, kv_cache, **kwargs)
 
-  input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
+  if "pixel_values" in kwargs:
+    kwargs["pixel_values"] = kwargs["pixel_values"].numpy()
+  kwargs.update({k: v.numpy() for k, v in zip(flat_names, values)})
   edge_output = edge_model(
       signature_name=signature_name,
       tokens=tokens.numpy(),
       input_pos=input_pos.numpy(),
-      **input_kv_flatten,
+      **kwargs,
   )
 
   return np.allclose(
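The helper leans on flattening the nested KV cache into individually named numpy arrays before calling the TFLite signature. A rough sketch of that flatten-and-convert pattern (the toy dict and hard-coded names are illustrative; the real code derives names via common_utils.flat_dict_names):

import torch
from torch.utils import _pytree as pytree

# Toy stand-in for a KV cache: two named tensors.
kv_cache = {"k_0": torch.zeros(1, 2), "v_0": torch.ones(1, 2)}
values, spec = pytree.tree_flatten({"kv_cache": kv_cache})

# Illustrative flat names; the library computes these from the tree spec.
flat_names = ["kv_cache_k_0", "kv_cache_v_0"]
numpy_kwargs = {name: v.numpy() for name, v in zip(flat_names, values)}
# numpy_kwargs can now be splatted into the edge model call alongside
# tokens and input_pos, exactly as compare_tflite_torch does above.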
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -15,9 +15,11 @@
 
 """Common utility functions for model conversion."""
 
-import ai_edge_torch
+from typing import Union
+
 from ai_edge_torch._convert import converter as converter_utils
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -25,109 +27,74 @@ import torch
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
     tflite_path: str,
-    prefill_seq_len: int = 512,
+    prefill_seq_len: Union[int, list[int]],
+    pixel_values_size: torch.Size = None,
     quantize: bool = True,
+    config: cfg.ModelConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.
 
-  A PyTorch model will be converted to a tflite model with two signatures:
-  "prefill" and "decode".
-
-  "prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
-  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
-  external KV cache as a sample input.
-
-  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
-  of shape [1, 1] of the token position, and an external KV cache as a sample
-  input.
+  A PyTorch model will be converted to a tflite model with several signatures:
+    * "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
+      passed),
+    * "prefill_[prefill_seq_len]_pixel" (or "prefill_pixel" if only one
+      prefill_seq_len is passed) if num_pixel_values > 0, and
+    * "decode".
+
+  "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
+  passed) signature takes as a sample input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions, and
+    * an external KV cache.
+
+  If num_pixel_values > 0, "prefill_[prefill_seq_len]_pixel" (or "prefill_pixel"
+  if only one prefill_seq_len is passed) signature takes as a sample input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions,
+    * an external KV cache, and
+    * a tensor of shape [1, num_pixel_values] of pixel values.
+
+  "decode" signature takes as a sample input:
+    * a tensor of shape [1, 1] of token sequence,
+    * a tensor of shape [1, 1] of the token position, and
+    * an external KV cache.
 
   The final tflite model will be exported to tflite_path.
 
   Args:
     pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
     tflite_path (str): The tflite file path to export.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
+    prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to
+      export.
+    pixel_values_size (torch.Size, optional): The size of pixel values to pass
+      to the model. If None, the model is not expected to take pixel values.
     quantize (bool, optional): Whether the model should be quantized. Defaults
       to True.
+    config (cfg.ModelConfig, optional): The model config used to configure KV
+      cache. If None, it uses the config of the pytorch_model.
   """
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
+  prefill_seq_lens = (
+      [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
   )
-  edge_model.export(tflite_path)
-
-
-def convert_to_tflite_multi_prefill_lens(
-    pytorch_model: torch.nn.Module,
-    tflite_path: str,
-    prefill_seq_lens: list[int],
-    quantize: bool = True,
-):
-  """Converts a nn.Module model to multi-signature tflite model with different
-
-  prefill lengths.
-
-  A PyTorch model will be converted to a tflite model with several signatures:
-  "prefill_[prefill_seq_len]" and "decode".
-
-  "prefill_[prefill_seq_len]" signature takes a tensor of shape [1,
-  prefill_seq_len] of token
-  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
-  external KV cache as a sample input.
-
-  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
-  of shape [1, 1] of the token position, and an external KV cache as a sample
-  input.
-
-  The final tflite model will be exported to tflite_path.
 
-  Args:
-    pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-    tflite_path (str): The tflite file path to export.
-    prefill_seq_lens (list[int]): A list of prefill lengths to export.
-    quantize (bool, optional): Whether the model should be quantized. Defaults
-      to True.
-  """
   # Tensors used to trace the model graph during conversion.
   prefill_tokens_list = []
   prefill_input_pos_list = []
-  for prefill_seq_len in prefill_seq_lens:
-    prefill_tokens_list.append(
-        torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-    )
-    prefill_input_pos_list.append(
-        torch.arange(0, prefill_seq_len, dtype=torch.int)
-    )
+  for seq_len in prefill_seq_lens:
+    prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
+    prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
+
+  prefill_pixel_values = (
+      torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
+      if pixel_values_size
+      else None
+  )
 
   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+  kv = kv_utils.KVCache.from_model_config(
+      config if config else pytorch_model.config
+  )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   converter = converter_utils.Converter()
@@ -135,8 +102,12 @@ def convert_to_tflite_multi_prefill_lens(
     prefill_seq_len = prefill_seq_lens[i]
     prefill_tokens = prefill_tokens_list[i]
     prefill_input_pos = prefill_input_pos_list[i]
+    if i == 0 and len(prefill_seq_lens) == 1:
+      prefill_signature_name = 'prefill'
+    else:
+      prefill_signature_name = f'prefill_{prefill_seq_len}'
     converter.add_signature(
-        f'prefill_{prefill_seq_len}',
+        prefill_signature_name,
         pytorch_model,
         sample_kwargs={
             'tokens': prefill_tokens,
@@ -144,8 +115,19 @@
             'kv_cache': kv,
         },
     )
+    if prefill_pixel_values is not None:
+      converter.add_signature(
+          prefill_signature_name + '_pixel',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+              'pixel_values': prefill_pixel_values,
+          },
+      )
 
-  edge_model = converter.add_signature(
+  converter.add_signature(
       'decode',
       pytorch_model,
       sample_kwargs={
@@ -153,5 +135,7 @@ def convert_to_tflite_multi_prefill_lens(
           'input_pos': decode_input_pos,
           'kv_cache': kv,
       },
-  ).convert(quant_config=quant_config)
+  )
+
+  edge_model = converter.convert(quant_config=quant_config)
   edge_model.export(tflite_path)
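Distilled from the converter changes above: a single prefill length keeps the legacy 'prefill' (and 'prefill_pixel') names, while multiple lengths get length-suffixed signatures. A compact sketch of that naming rule (hypothetical helper, not part of the library):

def prefill_signature_names(prefill_seq_lens, with_pixel=False):
  """Hypothetical helper mirroring the naming logic in convert_to_tflite."""
  names = []
  for seq_len in prefill_seq_lens:
    base = 'prefill' if len(prefill_seq_lens) == 1 else f'prefill_{seq_len}'
    names.append(base)
    if with_pixel:
      names.append(base + '_pixel')
  return names


assert prefill_signature_names([1024]) == ['prefill']
assert prefill_signature_names([8, 64], with_pixel=True) == [
    'prefill_8', 'prefill_8_pixel', 'prefill_64', 'prefill_64_pixel'
]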
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241116"
+__version__ = "0.3.0.dev20241119"
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/METADATA → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241116
+Version: 0.3.0.dev20241119
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/RECORD → ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=VA2R7z515pfD79tg2AjlwXASYb6LSz0-kch5NJzdj3k,706
+ai_edge_torch/version.py,sha256=W10ztMY91LZfiK-COm46eLLfufHNyUl2W0DtuM8zeC4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -45,8 +45,7 @@ ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n7
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
-ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=bZKOiAJBWPzIVHdASEgKRUFdyZSPVGFfe3uXUYrRh1c,2868
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=Mss7j3yDhyOFCWA93iWh995CLeNBDTVG-gvpj6WBIp0,2226
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
@@ -61,9 +60,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=dT7dnx1dzGzFiH5gQJ4M6zcTLSRFvSDpi3IuZ9_vd78,2706
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=mbq9CBp2znXPIQdzIQTiQGRh4Ql3bn9kyX-k_LXKTms,4537
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
@@ -135,12 +135,12 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TzBEbWOoB7bIHePuP6ySL9eYfmKHpONgTQCU-f05m8c,9497
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
-ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
+ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
+ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -193,8 +193,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/METADATA,sha256=OyMmJ6EACAhEKbHNgLaGAogbjR8DwCLHYfIDdKW7iMI,1897
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/METADATA,sha256=ikosXUollu7saRd3GUi1IZw78jvgCShWAhtDF3NuUtE,1897
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241119.dist-info/RECORD,,