ai-edge-torch-nightly 0.3.0.dev20250204__py3-none-any.whl → 0.3.0.dev20250205__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +1 -5
- ai_edge_torch/generative/examples/paligemma/decoder.py +2 -3
- ai_edge_torch/generative/examples/paligemma/decoder2.py +6 -17
- ai_edge_torch/generative/examples/paligemma/paligemma.py +3 -9
- ai_edge_torch/generative/examples/paligemma/verify.py +1 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/top_level.txt +0 -0
@@ -13,11 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
"""Example of converting a PaliGemma model to multi-signature tflite model.
|
17
|
-
|
18
|
-
DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
|
19
|
-
https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
|
20
|
-
"""
|
16
|
+
"""Example of converting a PaliGemma model to multi-signature tflite model."""
|
21
17
|
|
22
18
|
import os
|
23
19
|
import pathlib
|
@@ -55,7 +55,6 @@ class Decoder(model_builder.DecoderOnlyModel):
|
|
55
55
|
input_embeds: torch.Tensor = None,
|
56
56
|
mask: Optional[torch.Tensor] = None,
|
57
57
|
export_config: Optional[model_builder.ExportConfig] = None,
|
58
|
-
called_by_generate: bool = True,
|
59
58
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
60
59
|
if input_embeds is None:
|
61
60
|
return super().forward(
|
@@ -64,11 +63,11 @@ class Decoder(model_builder.DecoderOnlyModel):
|
|
64
63
|
|
65
64
|
assert input_embeds is not None
|
66
65
|
|
67
|
-
|
66
|
+
rope_pos = input_pos + 1 # PaliGemma position is 1-based.
|
68
67
|
# ROPE parameters for all attn_configs are the same. Take the first one.
|
69
68
|
attn_config = self.config.block_config(0).attn_config
|
70
69
|
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
|
71
|
-
rope = self.config.build_rope(
|
70
|
+
rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
|
72
71
|
|
73
72
|
# The first part of input_embeds are image embeddings. Diagonal causal mask
|
74
73
|
# doesn't work here.
|
@@ -58,34 +58,23 @@ class Decoder2(gemma2.Gemma2):
|
|
58
58
|
input_embeds: torch.Tensor = None,
|
59
59
|
mask: Optional[torch.Tensor] = None,
|
60
60
|
export_config: Optional[model_builder.ExportConfig] = None,
|
61
|
-
called_by_generate: bool = True,
|
62
61
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
63
62
|
if input_embeds is None:
|
64
63
|
return super().forward(tokens, input_pos, kv_cache, mask, export_config)
|
65
64
|
|
66
65
|
assert input_embeds is not None
|
67
66
|
|
68
|
-
|
67
|
+
rope_pos = input_pos + 1 # PaliGemma2 position is 1-based.
|
69
68
|
# ROPE parameters for all attn_configs are the same. Take the first one.
|
70
69
|
attn_config = self.config.block_config(0).attn_config
|
71
70
|
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
|
72
|
-
rope = self.config.build_rope(
|
71
|
+
rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
|
73
72
|
|
74
73
|
if mask is None:
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
self.get_attention_mask(
|
80
|
-
self.config.block_config(i).attn_config.attn_type, input_pos
|
81
|
-
)
|
82
|
-
for i in range(self.config.num_layers)
|
83
|
-
]
|
84
|
-
else:
|
85
|
-
# By default, don't mask image embeds with a diagonal causal mask.
|
86
|
-
embeds_len = input_embeds.shape[1]
|
87
|
-
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
|
88
|
-
mask[:, embeds_len:] = float("-inf")
|
74
|
+
# By default, don't mask image embeds with a diagonal causal mask.
|
75
|
+
embeds_len = input_embeds.shape[1]
|
76
|
+
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
|
77
|
+
mask[:, embeds_len:] = float("-inf")
|
89
78
|
|
90
79
|
return self._forward_with_embeds(
|
91
80
|
input_embeds, rope, mask, input_pos, kv_cache, export_config
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Example of building a full-stack of PaliGemma model."""
|
17
17
|
|
18
|
-
|
18
|
+
import dataclasses
|
19
19
|
from typing import Optional
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.examples.paligemma import decoder
|
@@ -31,7 +31,7 @@ from torch import nn
|
|
31
31
|
PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
|
32
32
|
|
33
33
|
|
34
|
-
@dataclass
|
34
|
+
@dataclasses.dataclass
|
35
35
|
class PaliGemmaConfig:
|
36
36
|
"""PaliGemma model configurations."""
|
37
37
|
|
@@ -39,7 +39,6 @@ class PaliGemmaConfig:
|
|
39
39
|
decoder_config: cfg.ModelConfig
|
40
40
|
|
41
41
|
image_token_id: int
|
42
|
-
image_projection_scale: float
|
43
42
|
image_projection_use_bias: bool = False
|
44
43
|
|
45
44
|
|
@@ -73,7 +72,6 @@ class PaliGemma(nn.Module):
|
|
73
72
|
mask: Optional[torch.Tensor] = None,
|
74
73
|
pixel_values: torch.Tensor = None,
|
75
74
|
export_config: Optional[model_builder.ExportConfig] = None,
|
76
|
-
called_by_generate: bool = True,
|
77
75
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
78
76
|
if pixel_values is None:
|
79
77
|
return self.decoder(
|
@@ -83,14 +81,13 @@ class PaliGemma(nn.Module):
|
|
83
81
|
mask=mask,
|
84
82
|
input_embeds=None,
|
85
83
|
export_config=export_config,
|
86
|
-
called_by_generate=called_by_generate,
|
87
84
|
)
|
88
85
|
|
89
86
|
input_embeds = self.decoder.tok_embedding(tokens)
|
90
87
|
|
91
88
|
image_encoded = self.image_encoder(pixel_values=pixel_values)
|
92
89
|
image_embeds = self.image_projection(image_encoded)
|
93
|
-
image_embeds = image_embeds / self.config.image_projection_scale
|
90
|
+
image_embeds = image_embeds / self.config.decoder_config.embedding_scale
|
94
91
|
|
95
92
|
# Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
|
96
93
|
# can be done like:
|
@@ -116,7 +113,6 @@ class PaliGemma(nn.Module):
|
|
116
113
|
mask=mask,
|
117
114
|
input_embeds=input_embeds,
|
118
115
|
export_config=export_config,
|
119
|
-
called_by_generate=called_by_generate,
|
120
116
|
)
|
121
117
|
|
122
118
|
|
@@ -130,7 +126,6 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
|
|
130
126
|
image_encoder_config=image_encoder.get_image_encoder_config(),
|
131
127
|
decoder_config=get_decoder_config(**kwargs),
|
132
128
|
image_token_id=257152,
|
133
|
-
image_projection_scale=2048**0.5,
|
134
129
|
image_projection_use_bias=True,
|
135
130
|
)
|
136
131
|
|
@@ -140,7 +135,6 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
|
|
140
135
|
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
|
141
136
|
decoder_config=get_decoder_config(**kwargs),
|
142
137
|
image_token_id=127,
|
143
|
-
image_projection_scale=128**0.5,
|
144
138
|
image_projection_use_bias=True,
|
145
139
|
)
|
146
140
|
|
@@ -41,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
|
|
41
41
|
)
|
42
42
|
_PROMPTS = flags.DEFINE_string(
|
43
43
|
"prompts",
|
44
|
-
"describe en",
|
44
|
+
"<image><bos>describe en",
|
45
45
|
"The input prompts to generate answers.",
|
46
46
|
)
|
47
47
|
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
@@ -59,16 +59,9 @@ _CHECKPOINT = {
|
|
59
59
|
class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
|
60
60
|
"""Reauthored PaliGemma model wrapper."""
|
61
61
|
|
62
|
-
def __init__(self, model: torch.nn.Module):
|
63
|
-
super().__init__(model)
|
64
|
-
self.forward_called_by_generate = False
|
65
|
-
|
66
62
|
def _init_kv_cache(self):
|
67
63
|
return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
|
68
64
|
|
69
|
-
def _get_extra_args_for_forward(self):
|
70
|
-
return {"called_by_generate": self.forward_called_by_generate}
|
71
|
-
|
72
65
|
|
73
66
|
def main(_):
|
74
67
|
if _VERSION.value == "1":
|
@@ -137,7 +130,6 @@ def main(_):
|
|
137
130
|
logging.info("outputs_from_original_model: [[%s]]", response_original)
|
138
131
|
|
139
132
|
logging.info("Generating answer with the reauthored model...")
|
140
|
-
wrapped_reauthored_model.forward_called_by_generate = True
|
141
133
|
outputs_reauthored = wrapped_reauthored_model.generate(
|
142
134
|
prompts=inputs["input_ids"],
|
143
135
|
pixel_values=inputs["pixel_values"],
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20250204
|
3
|
+
Version: 0.3.0.dev20250205
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=3qCqU6b85lrBJn0A7eFSW9dGx1TkEsCXhffIwwFwUv4,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -73,12 +73,12 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
|
|
73
73
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
|
74
74
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
75
75
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
76
|
-
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
|
77
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
78
|
-
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
|
76
|
+
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
|
77
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
|
78
|
+
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
|
79
79
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
|
80
|
-
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
|
81
|
-
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=
|
80
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
|
81
|
+
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
|
82
82
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
83
83
|
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
|
84
84
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
|
@@ -227,8 +227,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
227
227
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
228
228
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
229
229
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
230
|
-
ai_edge_torch_nightly-0.3.0.
|
231
|
-
ai_edge_torch_nightly-0.3.0.
|
232
|
-
ai_edge_torch_nightly-0.3.0.
|
233
|
-
ai_edge_torch_nightly-0.3.0.
|
234
|
-
ai_edge_torch_nightly-0.3.0.
|
230
|
+
ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
231
|
+
ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/METADATA,sha256=F9YG6dtQw7Vh9T4m0C2z4JAiddvpobcdY-Rxjmh4WX4,1966
|
232
|
+
ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
233
|
+
ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
234
|
+
ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/RECORD,,
|
File without changes
|
File without changes
|