ai-edge-torch-nightly 0.3.0.dev20250204__py3-none-any.whl → 0.3.0.dev20250205__py3-none-any.whl

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -13,11 +13,7 @@
  # limitations under the License.
  # ==============================================================================
 
- """Example of converting a PaliGemma model to multi-signature tflite model.
-
- DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
- https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
- """
+ """Example of converting a PaliGemma model to multi-signature tflite model."""
 
  import os
  import pathlib
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -55,7 +55,6 @@ class Decoder(model_builder.DecoderOnlyModel):
  input_embeds: torch.Tensor = None,
  mask: Optional[torch.Tensor] = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if input_embeds is None:
  return super().forward(
@@ -64,11 +63,11 @@ class Decoder(model_builder.DecoderOnlyModel):
 
  assert input_embeds is not None
 
- repo_pos = input_pos + 1 # PaliGemma position is 1-based.
+ rope_pos = input_pos + 1 # PaliGemma position is 1-based.
  # ROPE parameters for all attn_configs are the same. Take the first one.
  attn_config = self.config.block_config(0).attn_config
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
- rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
 
  # The first part of input_embeds are image embeddings. Diagonal causal mask
  # doesn't work here.
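Note on the rope_pos rename above: the decoder builds its rotary position embeddings from 1-based positions. The sketch below is a minimal, self-contained illustration of that shift; build_rope_cos_sin, its signature, and its defaults are stand-ins for illustration only, not the library's build_rope API.

import torch


def build_rope_cos_sin(positions: torch.Tensor, n_elem: int, base: float = 10_000.0):
  # Standard RoPE recipe: one rotation frequency per pair of head dimensions.
  freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  angles = positions.float().unsqueeze(-1) * freqs  # [seq_len, n_elem // 2]
  return torch.cos(angles), torch.sin(angles)


# With the 1-based convention, the token at index 0 is rotated as position 1,
# the token at index 1 as position 2, and so on.
input_pos = torch.arange(4)                  # 0-based positions used elsewhere
cos, sin = build_rope_cos_sin(input_pos + 1, n_elem=8)
print(cos.shape)  # torch.Size([4, 4])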
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -58,34 +58,23 @@ class Decoder2(gemma2.Gemma2):
  input_embeds: torch.Tensor = None,
  mask: Optional[torch.Tensor] = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if input_embeds is None:
  return super().forward(tokens, input_pos, kv_cache, mask, export_config)
 
  assert input_embeds is not None
 
- repo_pos = input_pos + 1 # PaliGemma2 position is 1-based.
+ rope_pos = input_pos + 1 # PaliGemma2 position is 1-based.
  # ROPE parameters for all attn_configs are the same. Take the first one.
  attn_config = self.config.block_config(0).attn_config
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
- rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
 
  if mask is None:
- if called_by_generate:
- # PaliGemma2 generate() uses a diagonal causal mask even with image
- # embeds.
- mask = [
- self.get_attention_mask(
- self.config.block_config(i).attn_config.attn_type, input_pos
- )
- for i in range(self.config.num_layers)
- ]
- else:
- # By default, don't mask image embeds with a diagonal causal mask.
- embeds_len = input_embeds.shape[1]
- mask = torch.zeros(embeds_len, self.config.kv_cache_max)
- mask[:, embeds_len:] = float("-inf")
+ # By default, don't mask image embeds with a diagonal causal mask.
+ embeds_len = input_embeds.shape[1]
+ mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+ mask[:, embeds_len:] = float("-inf")
 
  return self._forward_with_embeds(
  input_embeds, rope, mask, input_pos, kv_cache, export_config
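The simplified default mask above can be reproduced in isolation to see its shape: prefix positions (image plus text embeddings) attend to one another without a causal triangle, while KV-cache slots beyond the prefix are masked out. A toy-sized sketch, with embeds_len and kv_cache_max chosen only for illustration:

import torch

# Toy sizes; in the real model embeds_len is input_embeds.shape[1] and
# kv_cache_max comes from the decoder config.
embeds_len, kv_cache_max = 4, 8

# All prefix positions may attend to each other; slots past the prefix get -inf.
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")
print(mask[0])  # tensor([0., 0., 0., 0., -inf, -inf, -inf, -inf])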
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -15,7 +15,7 @@
 
  """Example of building a full-stack of PaliGemma model."""
 
- from dataclasses import dataclass
+ import dataclasses
  from typing import Optional
 
  from ai_edge_torch.generative.examples.paligemma import decoder
@@ -31,7 +31,7 @@ from torch import nn
  PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
 
 
- @dataclass
+ @dataclasses.dataclass
  class PaliGemmaConfig:
  """PaliGemma model configurations."""
 
@@ -39,7 +39,6 @@ class PaliGemmaConfig:
  decoder_config: cfg.ModelConfig
 
  image_token_id: int
- image_projection_scale: float
  image_projection_use_bias: bool = False
 
 
@@ -73,7 +72,6 @@ class PaliGemma(nn.Module):
  mask: Optional[torch.Tensor] = None,
  pixel_values: torch.Tensor = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if pixel_values is None:
  return self.decoder(
@@ -83,14 +81,13 @@ class PaliGemma(nn.Module):
  mask=mask,
  input_embeds=None,
  export_config=export_config,
- called_by_generate=called_by_generate,
  )
 
  input_embeds = self.decoder.tok_embedding(tokens)
 
  image_encoded = self.image_encoder(pixel_values=pixel_values)
  image_embeds = self.image_projection(image_encoded)
- image_embeds = image_embeds / self.config.image_projection_scale
+ image_embeds = image_embeds / self.config.decoder_config.embedding_scale
 
  # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
  # can be done like:
@@ -116,7 +113,6 @@ class PaliGemma(nn.Module):
  mask=mask,
  input_embeds=input_embeds,
  export_config=export_config,
- called_by_generate=called_by_generate,
  )
 
 
@@ -130,7 +126,6 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  image_encoder_config=image_encoder.get_image_encoder_config(),
  decoder_config=get_decoder_config(**kwargs),
  image_token_id=257152,
- image_projection_scale=2048**0.5,
  image_projection_use_bias=True,
  )
 
@@ -140,7 +135,6 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
  decoder_config=get_decoder_config(**kwargs),
  image_token_id=127,
- image_projection_scale=128**0.5,
  image_projection_use_bias=True,
  )
 
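On the image_projection_scale removal above: the deleted constants (2048**0.5 and 128**0.5) appear to coincide with the decoders' embedding_scale, which would explain why the projection output can now be divided by self.config.decoder_config.embedding_scale instead of carrying a duplicated field. A small sketch of that numeric relationship, using a hypothetical FakeDecoderConfig rather than the real cfg.ModelConfig:

import dataclasses


# Hypothetical stand-in for the decoder config, only to show the relationship;
# the real config is built elsewhere in the package.
@dataclasses.dataclass
class FakeDecoderConfig:
  embedding_dim: int

  @property
  def embedding_scale(self) -> float:
    # Gemma-style decoders scale token embeddings by sqrt(embedding_dim).
    return self.embedding_dim**0.5


decoder_config = FakeDecoderConfig(embedding_dim=2048)
# The deleted image_projection_scale=2048**0.5 is this same value, so the
# projected image embeddings can be divided by decoder_config.embedding_scale.
assert decoder_config.embedding_scale == 2048**0.5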
ai_edge_torch/generative/examples/paligemma/verify.py CHANGED
@@ -41,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
  )
  _PROMPTS = flags.DEFINE_string(
  "prompts",
- "describe en",
+ "<image><bos>describe en",
  "The input prompts to generate answers.",
  )
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -59,16 +59,9 @@ _CHECKPOINT = {
  class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
  """Reauthored PaliGemma model wrapper."""
 
- def __init__(self, model: torch.nn.Module):
- super().__init__(model)
- self.forward_called_by_generate = False
-
  def _init_kv_cache(self):
  return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
- def _get_extra_args_for_forward(self):
- return {"called_by_generate": self.forward_called_by_generate}
-
 
  def main(_):
  if _VERSION.value == "1":
@@ -137,7 +130,6 @@ def main(_):
  logging.info("outputs_from_original_model: [[%s]]", response_original)
 
  logging.info("Generating answer with the reauthored model...")
- wrapped_reauthored_model.forward_called_by_generate = True
  outputs_reauthored = wrapped_reauthored_model.generate(
  prompts=inputs["input_ids"],
  pixel_values=inputs["pixel_values"],
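The new default prompt "<image><bos>describe en" spells out the image placeholder and BOS token explicitly. A hedged sketch of how such a prompt is commonly tokenized with a Hugging Face processor follows; the checkpoint name, the local image path, and the exact processor behavior are assumptions for illustration, not taken from verify.py:

# Assumed checkpoint and image path; the real script builds its inputs from
# the --image_url and --prompts flags instead.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-224")
image = Image.open("test_image.jpg")
inputs = processor(
    text="<image><bos>describe en", images=image, return_tensors="pt"
)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)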
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20250204"
+ __version__ = "0.3.0.dev20250205"
ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/METADATA → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20250204
+ Version: 0.3.0.dev20250205
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/RECORD → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=4XOGz1x6yfOnkOtBndF7qE1L3Ma12ZMJNwQ7wIWkyEs,706
+ ai_edge_torch/version.py,sha256=3qCqU6b85lrBJn0A7eFSW9dGx1TkEsCXhffIwwFwUv4,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,12 +73,12 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=S_W-0ojRu2Vd5SLNPs1kC-70xHB8AdSWslm-yPxyezk,5478
- ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=W009ky-yobueTzdaybSCqBAvNyArLXW3jDyp5MarzZU,6376
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
+ ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
- ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
- ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
+ ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
+ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
  ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
@@ -227,8 +227,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/METADATA,sha256=Rf4w5EMQlNWOoFIuVlXUZPU9vmXlOJW7oB4yPrtgK0c,1966
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/METADATA,sha256=F9YG6dtQw7Vh9T4m0C2z4JAiddvpobcdY-Rxjmh4WX4,1966
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/RECORD,,