ai-edge-torch-nightly 0.3.0.dev20250204__py3-none-any.whl → 0.3.0.dev20250205__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,11 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- """Example of converting a PaliGemma model to multi-signature tflite model.
17
-
18
- DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
19
- https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
20
- """
16
+ """Example of converting a PaliGemma model to multi-signature tflite model."""
21
17
 
22
18
  import os
23
19
  import pathlib
@@ -55,7 +55,6 @@ class Decoder(model_builder.DecoderOnlyModel):
55
55
  input_embeds: torch.Tensor = None,
56
56
  mask: Optional[torch.Tensor] = None,
57
57
  export_config: Optional[model_builder.ExportConfig] = None,
58
- called_by_generate: bool = True,
59
58
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
60
59
  if input_embeds is None:
61
60
  return super().forward(
@@ -64,11 +63,11 @@ class Decoder(model_builder.DecoderOnlyModel):
64
63
 
65
64
  assert input_embeds is not None
66
65
 
67
- repo_pos = input_pos + 1 # PaliGemma position is 1-based.
66
+ rope_pos = input_pos + 1 # PaliGemma position is 1-based.
68
67
  # ROPE parameters for all attn_configs are the same. Take the first one.
69
68
  attn_config = self.config.block_config(0).attn_config
70
69
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
71
- rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
70
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
72
71
 
73
72
  # The first part of input_embeds are image embeddings. Diagonal causal mask
74
73
  # doesn't work here.
@@ -58,34 +58,23 @@ class Decoder2(gemma2.Gemma2):
58
58
  input_embeds: torch.Tensor = None,
59
59
  mask: Optional[torch.Tensor] = None,
60
60
  export_config: Optional[model_builder.ExportConfig] = None,
61
- called_by_generate: bool = True,
62
61
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
63
62
  if input_embeds is None:
64
63
  return super().forward(tokens, input_pos, kv_cache, mask, export_config)
65
64
 
66
65
  assert input_embeds is not None
67
66
 
68
- repo_pos = input_pos + 1 # PaliGemma2 position is 1-based.
67
+ rope_pos = input_pos + 1 # PaliGemma2 position is 1-based.
69
68
  # ROPE parameters for all attn_configs are the same. Take the first one.
70
69
  attn_config = self.config.block_config(0).attn_config
71
70
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
72
- rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
71
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
73
72
 
74
73
  if mask is None:
75
- if called_by_generate:
76
- # PaliGemma2 generate() uses a diagonal causal mask even with image
77
- # embeds.
78
- mask = [
79
- self.get_attention_mask(
80
- self.config.block_config(i).attn_config.attn_type, input_pos
81
- )
82
- for i in range(self.config.num_layers)
83
- ]
84
- else:
85
- # By default, don't mask image embeds with a diagonal causal mask.
86
- embeds_len = input_embeds.shape[1]
87
- mask = torch.zeros(embeds_len, self.config.kv_cache_max)
88
- mask[:, embeds_len:] = float("-inf")
74
+ # By default, don't mask image embeds with a diagonal causal mask.
75
+ embeds_len = input_embeds.shape[1]
76
+ mask = torch.zeros(embeds_len, self.config.kv_cache_max)
77
+ mask[:, embeds_len:] = float("-inf")
89
78
 
90
79
  return self._forward_with_embeds(
91
80
  input_embeds, rope, mask, input_pos, kv_cache, export_config
@@ -15,7 +15,7 @@
15
15
 
16
16
  """Example of building a full-stack of PaliGemma model."""
17
17
 
18
- from dataclasses import dataclass
18
+ import dataclasses
19
19
  from typing import Optional
20
20
 
21
21
  from ai_edge_torch.generative.examples.paligemma import decoder
@@ -31,7 +31,7 @@ from torch import nn
31
31
  PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
32
32
 
33
33
 
34
- @dataclass
34
+ @dataclasses.dataclass
35
35
  class PaliGemmaConfig:
36
36
  """PaliGemma model configurations."""
37
37
 
@@ -39,7 +39,6 @@ class PaliGemmaConfig:
39
39
  decoder_config: cfg.ModelConfig
40
40
 
41
41
  image_token_id: int
42
- image_projection_scale: float
43
42
  image_projection_use_bias: bool = False
44
43
 
45
44
 
@@ -73,7 +72,6 @@ class PaliGemma(nn.Module):
73
72
  mask: Optional[torch.Tensor] = None,
74
73
  pixel_values: torch.Tensor = None,
75
74
  export_config: Optional[model_builder.ExportConfig] = None,
76
- called_by_generate: bool = True,
77
75
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
78
76
  if pixel_values is None:
79
77
  return self.decoder(
@@ -83,14 +81,13 @@ class PaliGemma(nn.Module):
83
81
  mask=mask,
84
82
  input_embeds=None,
85
83
  export_config=export_config,
86
- called_by_generate=called_by_generate,
87
84
  )
88
85
 
89
86
  input_embeds = self.decoder.tok_embedding(tokens)
90
87
 
91
88
  image_encoded = self.image_encoder(pixel_values=pixel_values)
92
89
  image_embeds = self.image_projection(image_encoded)
93
- image_embeds = image_embeds / self.config.image_projection_scale
90
+ image_embeds = image_embeds / self.config.decoder_config.embedding_scale
94
91
 
95
92
  # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
96
93
  # can be done like:
@@ -116,7 +113,6 @@ class PaliGemma(nn.Module):
116
113
  mask=mask,
117
114
  input_embeds=input_embeds,
118
115
  export_config=export_config,
119
- called_by_generate=called_by_generate,
120
116
  )
121
117
 
122
118
 
@@ -130,7 +126,6 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
130
126
  image_encoder_config=image_encoder.get_image_encoder_config(),
131
127
  decoder_config=get_decoder_config(**kwargs),
132
128
  image_token_id=257152,
133
- image_projection_scale=2048**0.5,
134
129
  image_projection_use_bias=True,
135
130
  )
136
131
 
@@ -140,7 +135,6 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
140
135
  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
141
136
  decoder_config=get_decoder_config(**kwargs),
142
137
  image_token_id=127,
143
- image_projection_scale=128**0.5,
144
138
  image_projection_use_bias=True,
145
139
  )
146
140
 
@@ -41,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
41
41
  )
42
42
  _PROMPTS = flags.DEFINE_string(
43
43
  "prompts",
44
- "describe en",
44
+ "<image><bos>describe en",
45
45
  "The input prompts to generate answers.",
46
46
  )
47
47
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -59,16 +59,9 @@ _CHECKPOINT = {
59
59
  class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
60
60
  """Reauthored PaliGemma model wrapper."""
61
61
 
62
- def __init__(self, model: torch.nn.Module):
63
- super().__init__(model)
64
- self.forward_called_by_generate = False
65
-
66
62
  def _init_kv_cache(self):
67
63
  return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
68
64
 
69
- def _get_extra_args_for_forward(self):
70
- return {"called_by_generate": self.forward_called_by_generate}
71
-
72
65
 
73
66
  def main(_):
74
67
  if _VERSION.value == "1":
@@ -137,7 +130,6 @@ def main(_):
137
130
  logging.info("outputs_from_original_model: [[%s]]", response_original)
138
131
 
139
132
  logging.info("Generating answer with the reauthored model...")
140
- wrapped_reauthored_model.forward_called_by_generate = True
141
133
  outputs_reauthored = wrapped_reauthored_model.generate(
142
134
  prompts=inputs["input_ids"],
143
135
  pixel_values=inputs["pixel_values"],
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250204"
16
+ __version__ = "0.3.0.dev20250205"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250204
3
+ Version: 0.3.0.dev20250205
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=4XOGz1x6yfOnkOtBndF7qE1L3Ma12ZMJNwQ7wIWkyEs,706
5
+ ai_edge_torch/version.py,sha256=3qCqU6b85lrBJn0A7eFSW9dGx1TkEsCXhffIwwFwUv4,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,12 +73,12 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
73
73
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
74
74
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
75
75
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
77
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=S_W-0ojRu2Vd5SLNPs1kC-70xHB8AdSWslm-yPxyezk,5478
78
- ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=W009ky-yobueTzdaybSCqBAvNyArLXW3jDyp5MarzZU,6376
76
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
77
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
78
+ ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
79
79
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
80
- ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
81
- ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
80
+ ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
81
+ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
82
82
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
83
83
  ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
84
84
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
@@ -227,8 +227,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
227
227
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
228
228
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
229
229
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
230
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
231
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/METADATA,sha256=Rf4w5EMQlNWOoFIuVlXUZPU9vmXlOJW7oB4yPrtgK0c,1966
232
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
233
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
234
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/RECORD,,
230
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
231
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/METADATA,sha256=F9YG6dtQw7Vh9T4m0C2z4JAiddvpobcdY-Rxjmh4WX4,1966
232
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
233
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
234
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/RECORD,,