ai-edge-torch-nightly 0.3.0.dev20250204__py3-none-any.whl → 0.3.0.dev20250205__py3-none-any.whl
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +1 -5
- ai_edge_torch/generative/examples/paligemma/decoder.py +2 -3
- ai_edge_torch/generative/examples/paligemma/decoder2.py +6 -17
- ai_edge_torch/generative/examples/paligemma/paligemma.py +3 -9
- ai_edge_torch/generative/examples/paligemma/verify.py +1 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -13,11 +13,7 @@
# limitations under the License.
# ==============================================================================

-"""Example of converting a PaliGemma model to multi-signature tflite model.
-
-DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
-https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
-"""
+"""Example of converting a PaliGemma model to multi-signature tflite model."""

import os
import pathlib
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -55,7 +55,6 @@ class Decoder(model_builder.DecoderOnlyModel):
input_embeds: torch.Tensor = None,
mask: Optional[torch.Tensor] = None,
export_config: Optional[model_builder.ExportConfig] = None,
-called_by_generate: bool = True,
) -> dict[torch.Tensor, kv_utils.KVCache]:
if input_embeds is None:
return super().forward(
@@ -64,11 +63,11 @@ class Decoder(model_builder.DecoderOnlyModel):

assert input_embeds is not None

-
+rope_pos = input_pos + 1  # PaliGemma position is 1-based.
# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-rope = self.config.build_rope(
+rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)

# The first part of input_embeds are image embeddings. Diagonal causal mask
# doesn't work here.
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -58,34 +58,23 @@ class Decoder2(gemma2.Gemma2):
input_embeds: torch.Tensor = None,
mask: Optional[torch.Tensor] = None,
export_config: Optional[model_builder.ExportConfig] = None,
-called_by_generate: bool = True,
) -> dict[torch.Tensor, kv_utils.KVCache]:
if input_embeds is None:
return super().forward(tokens, input_pos, kv_cache, mask, export_config)

assert input_embeds is not None

-
+rope_pos = input_pos + 1  # PaliGemma2 position is 1-based.
# ROPE parameters for all attn_configs are the same. Take the first one.
attn_config = self.config.block_config(0).attn_config
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-rope = self.config.build_rope(
+rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)

if mask is None:
-
-
-
-
-self.get_attention_mask(
-self.config.block_config(i).attn_config.attn_type, input_pos
-)
-for i in range(self.config.num_layers)
-]
-else:
-# By default, don't mask image embeds with a diagonal causal mask.
-embeds_len = input_embeds.shape[1]
-mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-mask[:, embeds_len:] = float("-inf")
+# By default, don't mask image embeds with a diagonal causal mask.
+embeds_len = input_embeds.shape[1]
+mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+mask[:, embeds_len:] = float("-inf")

return self._forward_with_embeds(
input_embeds, rope, mask, input_pos, kv_cache, export_config
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -15,7 +15,7 @@

"""Example of building a full-stack of PaliGemma model."""

-
+import dataclasses
from typing import Optional

from ai_edge_torch.generative.examples.paligemma import decoder
@@ -31,7 +31,7 @@ from torch import nn
PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"


-@dataclass
+@dataclasses.dataclass
class PaliGemmaConfig:
"""PaliGemma model configurations."""

@@ -39,7 +39,6 @@ class PaliGemmaConfig:
decoder_config: cfg.ModelConfig

image_token_id: int
-image_projection_scale: float
image_projection_use_bias: bool = False

@@ -73,7 +72,6 @@ class PaliGemma(nn.Module):
mask: Optional[torch.Tensor] = None,
pixel_values: torch.Tensor = None,
export_config: Optional[model_builder.ExportConfig] = None,
-called_by_generate: bool = True,
) -> dict[torch.Tensor, kv_utils.KVCache]:
if pixel_values is None:
return self.decoder(
@@ -83,14 +81,13 @@ class PaliGemma(nn.Module):
mask=mask,
input_embeds=None,
export_config=export_config,
-called_by_generate=called_by_generate,
)

input_embeds = self.decoder.tok_embedding(tokens)

image_encoded = self.image_encoder(pixel_values=pixel_values)
image_embeds = self.image_projection(image_encoded)
-image_embeds = image_embeds / self.config.
+image_embeds = image_embeds / self.config.decoder_config.embedding_scale

# Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
# can be done like:
@@ -116,7 +113,6 @@ class PaliGemma(nn.Module):
mask=mask,
input_embeds=input_embeds,
export_config=export_config,
-called_by_generate=called_by_generate,
)


@@ -130,7 +126,6 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
image_encoder_config=image_encoder.get_image_encoder_config(),
decoder_config=get_decoder_config(**kwargs),
image_token_id=257152,
-image_projection_scale=2048**0.5,
image_projection_use_bias=True,
)

@@ -140,7 +135,6 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
decoder_config=get_decoder_config(**kwargs),
image_token_id=127,
-image_projection_scale=128**0.5,
image_projection_use_bias=True,
)

ai_edge_torch/generative/examples/paligemma/verify.py CHANGED
@@ -41,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
)
_PROMPTS = flags.DEFINE_string(
"prompts",
-"describe en",
+"<image><bos>describe en",
"The input prompts to generate answers.",
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -59,16 +59,9 @@ _CHECKPOINT = {
class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
"""Reauthored PaliGemma model wrapper."""

-def __init__(self, model: torch.nn.Module):
-super().__init__(model)
-self.forward_called_by_generate = False
-
def _init_kv_cache(self):
return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)

-def _get_extra_args_for_forward(self):
-return {"called_by_generate": self.forward_called_by_generate}
-

def main(_):
if _VERSION.value == "1":
@@ -137,7 +130,6 @@ def main(_):
logging.info("outputs_from_original_model: [[%s]]", response_original)

logging.info("Generating answer with the reauthored model...")
-wrapped_reauthored_model.forward_called_by_generate = True
outputs_reauthored = wrapped_reauthored_model.generate(
prompts=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
ai_edge_torch/version.py CHANGED
{ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250204
+Version: 0.3.0.dev20250205
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
Home-page: https://github.com/google-ai-edge/ai-edge-torch
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=3qCqU6b85lrBJn0A7eFSW9dGx1TkEsCXhffIwwFwUv4,706
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,12 +73,12 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
-ai_edge_torch/generative/examples/paligemma/verify.py,sha256=
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
+ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
@@ -227,8 +227,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/METADATA,sha256=F9YG6dtQw7Vh9T4m0C2z4JAiddvpobcdY-Rxjmh4WX4,1966
+ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/RECORD,,
{ai_edge_torch_nightly-0.3.0.dev20250204.dist-info → ai_edge_torch_nightly-0.3.0.dev20250205.dist-info}/LICENSE, WHEEL, top_level.txt: files without changes.