ai-edge-torch-nightly 0.3.0.dev20241218__py3-none-any.whl → 0.3.0.dev20241224__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +43 -25
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
- ai_edge_torch/generative/examples/paligemma/decoder.py +14 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +174 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +30 -15
- ai_edge_torch/generative/examples/paligemma/verify.py +36 -9
- ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +24 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +4 -29
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
- ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
- ai_edge_torch/generative/utilities/model_builder.py +14 -14
- ai_edge_torch/generative/utilities/verifier.py +22 -22
- ai_edge_torch/odml_torch/export.py +6 -1
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +4 -1
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -2
- ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD +26 -23
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma1.py

@@ -72,12 +72,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=
-      embedding_scale=
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
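The hunk above names the Gemma 1 2B embedding width and stores the sqrt(embedding_dim) factor in ModelConfig.embedding_scale, so the scale comes from the config at run time rather than being hard-coded in the forward pass (the gemma2.py hunks below apply it the same way). A minimal sketch of that config-driven scaling, with illustrative names that are not part of the library API:

import torch

def apply_embedding_scale(embeds: torch.Tensor, embedding_scale) -> torch.Tensor:
  # Scale only when the config carries a value; otherwise pass through,
  # mirroring the `if self.config.embedding_scale is not None` check below.
  return embeds * embedding_scale if embedding_scale is not None else embeds

embedding_dim = 2048  # value introduced in the hunk above
x = torch.randn(1, 4, embedding_dim)
x_scaled = apply_embedding_scale(x, embedding_dim**0.5)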
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,48 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [self.get_attention_mask(
+        self.config.block_config(i).attn_config.attn_type, input_pos
+    ) for i in range(self.config.num_layers)]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
    )
 
-
-
-
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x,
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-
-        updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
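The rewrite above removes the precomputed rope_cache and instead calls rotary_pos_emb.build_rope(...) for the positions actually being processed, then hands the embeddings, RoPE tensors, and per-layer masks to the new _forward_with_embeds(). The build_rope helper itself is not shown in this diff; the sketch below only illustrates what such a helper conventionally computes (cos/sin tables for the requested positions) and should be read as an assumption, not the library implementation:

import torch

def build_rope_sketch(input_pos, n_elem, head_dim, base=10000):
  """Illustrative stand-in for rotary_pos_emb.build_rope; not the library code."""
  # Standard RoPE inverse frequencies over the rotated channels.
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  # Angles only for the positions being decoded, so no max-length cache is kept.
  idx_theta = torch.outer(input_pos.float(), theta)
  # head_dim is accepted to mirror the call in the hunk above; this sketch
  # does not use it to expand or interleave the tables.
  return torch.cos(idx_theta), torch.sin(idx_theta)

cos, sin = build_rope_sketch(torch.arange(4), n_elem=256, head_dim=256)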
@@ -228,11 +243,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +266,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+_VERSION = flags.DEFINE_enum(
+    'version',
+    '2',
+    ['1', '2'],
+    'The version of PaliGemma model to verify.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(
 
 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value,
+      _CHECKPOINT_PATH.value,
+      version=int(_VERSION.value),
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'
+  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
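With the new --version flag (defaulting to '2'), the exported filename now records the model version together with the quantization suffix, prefill length, and KV-cache size. A small, self-contained illustration of the naming scheme in the f-string above, using made-up flag values:

version, quant_suffix, prefill_seq_len, kv_cache_max_len = "2", "q8", 1024, 1280
output_filename = (
    f"paligemma{version}_{quant_suffix}_seq{prefill_seq_len}"
    f"_ekv{kv_cache_max_len}.tflite"
)
assert output_filename == "paligemma2_q8_seq1024_ekv1280.tflite"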
ai_edge_torch/generative/examples/paligemma/decoder.py

@@ -19,6 +19,7 @@ from typing import Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -54,6 +55,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -61,8 +63,12 @@ class Decoder(model_builder.DecoderOnlyModel):
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-
-
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
@@ -70,7 +76,7 @@ class Decoder(model_builder.DecoderOnlyModel):
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
 
-    return self.
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
     )
 
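For prefill with mixed image and text embeddings, the decoder keeps the non-diagonal mask shown in the context lines above: every query position may attend to the whole embedding prefix and to nothing beyond it in the KV cache. A standalone rendering of that mask construction with small illustrative sizes:

import torch

embeds_len, kv_cache_max = 6, 16  # illustrative sizes, not model defaults
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")
# Rows are query positions over the image+text prefix; columns are KV cache
# slots. Everything past the prefix is blocked with -inf.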
@@ -108,12 +114,13 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=257216,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=
-      embedding_scale=
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -130,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config
 
 
ai_edge_torch/generative/examples/paligemma/decoder2.py (new file)

@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma2 3B model which is Gemma2."""
+
+from typing import Optional
+
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+class Decoder2(gemma2.Gemma2):
+  """A decoder of PaliGemma2 3B model which is Gemma2.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      return super().forward(tokens, input_pos, kv_cache)
+
+    assert input_embeds is not None
+
+    repo_pos = input_pos + 1  # PaliGemma2 position is 1-based.
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
+    if called_by_generate:
+      # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+      mask = [self.get_attention_mask(
+          self.config.block_config(i).attn_config.attn_type, input_pos
+      ) for i in range(self.config.num_layers)]
+    else:
+      # By default, don't mask image embeds with a diagonal causal mask.
+      embeds_len = input_embeds.shape[1]
+      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask[:, embeds_len:] = float("-inf")
+      mask = [mask] * self.config.num_layers
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+
+def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=9216,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_base=10000,
+        rotary_percentage=1.0,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
+  embedding_dim = 2304
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=num_layers,
+      max_seq_len=8192,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+      final_logit_softcap=30.0,
+  )
+  return config
+
+
+def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder2_config(kv_cache_max_len)
+  # PaliGemma2 decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
+  return config
+
+
+def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder2_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Decoder2,
+  )
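Decoder2 above subclasses gemma2.Gemma2 and encodes PaliGemma2's alternating global and local sliding-window attention directly in its block configs. Assuming a build of ai-edge-torch that includes this file, a quick check of the pattern against the fake (tiny) config could look like this:

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.examples.paligemma import decoder2

config = decoder2.get_fake_decoder2_config(kv_cache_max_len=128)
attn_types = [
    config.block_config(i).attn_config.attn_type
    for i in range(config.num_layers)
]
# Even-indexed layers use global attention, odd-indexed layers use local
# sliding-window attention (window_size=4096 in the full config).
assert attn_types[0] == cfg.AttentionType.GLOBAL
assert attn_types[1] == cfg.AttentionType.LOCAL_SLIDING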
ai_edge_torch/generative/examples/paligemma/paligemma.py

@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
@@ -38,13 +39,14 @@ class PaliGemmaConfig:
   decoder_config: cfg.ModelConfig
 
   image_token_id: int
+  image_projection_scale: float
   image_projection_use_bias: bool = False
 
 
 class PaliGemma(nn.Module):
   """PaliGemma model from the Edge Generative API."""
 
-  def __init__(self, config: PaliGemmaConfig):
+  def __init__(self, config: PaliGemmaConfig, decoder_class: nn.Module):
     super().__init__()
 
     self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -55,7 +57,7 @@ class PaliGemma(nn.Module):
         config.decoder_config.embedding_dim,
         bias=config.image_projection_use_bias,
     )
-    self.decoder =
+    self.decoder = decoder_class(config.decoder_config)
     image_embedding_config = config.image_encoder_config.image_embedding
     self.num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
@@ -70,6 +72,7 @@ class PaliGemma(nn.Module):
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
       return self.decoder(
@@ -77,15 +80,15 @@ class PaliGemma(nn.Module):
           input_pos=input_pos,
           kv_cache=kv_cache,
           input_embeds=None,
-          export_config=export_config
+          export_config=export_config,
+          called_by_generate=called_by_generate,
       )
 
     input_embeds = self.decoder.tok_embedding(tokens)
 
     image_encoded = self.image_encoder(pixel_values=pixel_values)
     image_embeds = self.image_projection(image_encoded)
-
-    image_embeds = image_embeds / self.config.decoder_config.embedding_scale
+    image_embeds = image_embeds / self.config.image_projection_scale
 
     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -110,10 +113,11 @@ class PaliGemma(nn.Module):
         kv_cache=kv_cache,
         input_embeds=input_embeds,
         export_config=export_config,
+        called_by_generate=called_by_generate,
     )
 
 
-def get_model_config(**kwargs) -> PaliGemmaConfig:
+def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
@@ -121,31 +125,42 @@ def get_model_config(**kwargs) -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=
-      image_projection_use_bias=True,
+      decoder_config=get_decoder_config(**kwargs),
       image_token_id=257152,
+      image_projection_scale=2048**0.5,
+      image_projection_use_bias=True,
   )
 
 
-def get_fake_model_config() -> PaliGemmaConfig:
+def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=
+      decoder_config=get_decoder_config(**kwargs),
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
-      image_token_id=257152,
   )
 
 
-def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
-
-
+def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+  if version == 1:
+    decoder_class = decoder.Decoder
+    decoder_tensor_names = decoder.TENSOR_NAMES
+    get_decoder_config = decoder.get_decoder_config
+  else:
+    decoder_class = decoder2.Decoder2
+    decoder_tensor_names = decoder2.TENSOR_NAMES
+    get_decoder_config = decoder2.get_decoder2_config
+
+  config = get_model_config(get_decoder_config, **kwargs)
+  model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
       checkpoint_path, image_encoder.TENSOR_NAMES
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path,
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
   loader.load(model.decoder, strict=False)
 
   # Load the parameters of image projection.
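build_model() now picks the decoder class, tensor names, and decoder config by version, with PaliGemma2 as the default. A hedged usage sketch with placeholder checkpoint paths (kv_cache_max_len is forwarded to the selected get_decoder_config):

from ai_edge_torch.generative.examples.paligemma import paligemma

# Placeholder paths; point these at real downloaded checkpoints.
pg2 = paligemma.build_model("/path/to/paligemma2-3b-224")  # version defaults to 2
pg1 = paligemma.build_model(
    "/path/to/paligemma-3b-224", version=1, kv_cache_max_len=1024
)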
ai_edge_torch/generative/examples/paligemma/verify.py

@@ -22,11 +22,18 @@ from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import verifier
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers
 
+_VERSION = flags.DEFINE_enum(
+    "version",
+    "2",
+    ["1", "2"],
+    "The version of PaliGemma model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
@@ -34,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
 )
 _PROMPTS = flags.DEFINE_string(
     "prompts",
-    "
+    "describe en",
     "The input prompts to generate answers.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -43,28 +50,47 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+
 
 class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored PaliGemma model wrapper."""
 
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.forward_called_by_generate = False
+
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
+  def _get_extra_args_for_forward(self):
+    return {"called_by_generate": self.forward_called_by_generate}
+
 
 def main(_):
-
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = (
       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )
 
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = paligemma.build_model(
+  reauthored_model = paligemma.build_model(
+      reauthored_checkpoint, version=int(_VERSION.value)
+  )
 
   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
@@ -93,7 +119,7 @@ def main(_):
   logging.info("outputs_reauthored: %s", outputs_reauthored)
 
   try:
-    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-02)
   except AssertionError as e:
     logging.error("*** FAILED *** verify with forward()")
     raise e
@@ -111,6 +137,7 @@ def main(_):
   logging.info("outputs_from_original_model: [[%s]]", response_original)
 
   logging.info("Generating answer with the reauthored model...")
+  wrapped_reauthored_model.forward_called_by_generate = True
   outputs_reauthored = wrapped_reauthored_model.generate(
       prompts=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],