ai-edge-torch-nightly 0.3.0.dev20241218__py3-none-any.whl → 0.3.0.dev20241224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +43 -25
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
- ai_edge_torch/generative/examples/paligemma/decoder.py +14 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +174 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +30 -15
- ai_edge_torch/generative/examples/paligemma/verify.py +36 -9
- ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +24 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +4 -29
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
- ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
- ai_edge_torch/generative/utilities/model_builder.py +14 -14
- ai_edge_torch/generative/utilities/verifier.py +22 -22
- ai_edge_torch/odml_torch/export.py +6 -1
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +4 -1
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -2
- ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD +26 -23
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma1.py

@@ -72,12 +72,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
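The net effect of this hunk: the sqrt(d) embedding scale moves into the config (`embedding_scale=embedding_dim**0.5`) instead of being applied ad hoc in the forward pass. A minimal arithmetic sketch of what the decoder then does with it (the `embeds` tensor here is hypothetical; the conditional mirrors the `embedding_scale` handling in the gemma2.py forward-pass hunk further down):

```python
import torch

embedding_dim = 2048
embedding_scale = embedding_dim ** 0.5  # ~45.25; Gemma scales token embeddings by sqrt(d)

embeds = torch.randn(1, 4, embedding_dim)  # (batch, seq, embedding_dim), illustrative
if embedding_scale is not None:
  embeds = embeds * embedding_scale
```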
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -15,13 +15,14 @@

 """Example of building a Gemma2 model."""

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
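With the precomputed `rope_cache` gone, cos/sin tables are built per call from the requested positions (see the forward-pass hunk below). A rough, self-contained sketch of what a `build_rope`-style helper computes — illustrative only, not the library's `rotary_pos_emb.build_rope`, whose exact signature (it also takes `head_dim`) and output layout may differ:

```python
import torch

def build_rope_sketch(positions: torch.Tensor, n_elem: int, base: float = 10000.0):
  # One inverse frequency per pair of rotated channels (classic RoPE).
  inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  # Angle table of shape (num_positions, n_elem // 2).
  angles = positions.float().unsqueeze(-1) * inv_freq
  return torch.cos(angles), torch.sin(angles)

cos, sin = build_rope_sketch(torch.arange(8), n_elem=256)
```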
@@ -140,29 +136,48 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [self.get_attention_mask(
+        self.config.block_config(i).attn_config.attn_type, input_pos
+    ) for i in range(self.config.num_layers)]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )

-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (
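Gemma2 interleaves global and local sliding-window attention across layers, which is why `mask` becomes a per-layer list in the refactored forward pass above. A self-contained illustration of why the two layer types need different masks (the library builds its own via `attn_utils.build_causal_mask_cache` and `build_sliding_window_mask_cache`; this is just a conceptual sketch):

```python
import torch

def causal_mask(size: int) -> torch.Tensor:
  # Query position q may attend to any key k <= q.
  return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

def sliding_window_mask(size: int, window: int) -> torch.Tensor:
  # Causal, but additionally blocks keys more than `window` positions back.
  blocked_past = torch.tril(torch.full((size, size), float("-inf")), diagonal=-window)
  return causal_mask(size) + blocked_past

print(causal_mask(4))             # zeros on and below the diagonal
print(sliding_window_mask(4, 2))  # zeros only on a 2-wide causal band
```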
@@ -228,11 +243,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )

   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +266,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch

+_VERSION = flags.DEFINE_enum(
+    'version',
+    '2',
+    ['1', '2'],
+    'The version of PaliGemma model to verify.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(

 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      _CHECKPOINT_PATH.value,
+      version=int(_VERSION.value),
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
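For reference, the new flag maps straight onto `paligemma.build_model`'s `version` parameter; a hypothetical programmatic equivalent of the converter invocation (the checkpoint path is a placeholder):

```python
from ai_edge_torch.generative.examples.paligemma import paligemma

# version=1 selects the original decoder.Decoder; version=2 (the new default)
# selects the Gemma2-based decoder2.Decoder2.
model = paligemma.build_model(
    "/path/to/paligemma2-3b-224",  # placeholder checkpoint path
    version=2,
    kv_cache_max_len=1024,
)
```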
ai_edge_torch/generative/examples/paligemma/decoder.py

@@ -19,6 +19,7 @@ from typing import Optional

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -54,6 +55,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -61,8 +63,12 @@
     assert input_embeds is not None

     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )

     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
@@ -70,7 +76,7 @@ class Decoder(model_builder.DecoderOnlyModel):
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")

-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
     )

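The zeros/-inf mask used above implements prefix-LM attention over the image-plus-text prefix: every prefix query may attend to the whole prefix, while the not-yet-filled tail of the KV cache is blocked. A tiny standalone illustration of the same construction:

```python
import torch

embeds_len, kv_cache_max = 4, 8  # illustrative sizes
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")
# Each of the 4 prefix queries sees all 4 prefix keys (0.0 = allowed) and is
# blocked (-inf) from the 4 unwritten KV-cache slots.
print(mask)
```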
@@ -108,12 +114,13 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=257216,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -130,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config


ai_edge_torch/generative/examples/paligemma/decoder2.py (new file)

@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma2 3B model which is Gemma2."""
+
+from typing import Optional
+
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+class Decoder2(gemma2.Gemma2):
+  """A decoder of PaliGemma2 3B model which is Gemma2.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      return super().forward(tokens, input_pos, kv_cache)
+
+    assert input_embeds is not None
+
+    repo_pos = input_pos + 1  # PaliGemma2 position is 1-based.
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
+    if called_by_generate:
+      # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+      mask = [self.get_attention_mask(
+          self.config.block_config(i).attn_config.attn_type, input_pos
+      ) for i in range(self.config.num_layers)]
+    else:
+      # By default, don't mask image embeds with a diagonal causal mask.
+      embeds_len = input_embeds.shape[1]
+      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask[:, embeds_len:] = float("-inf")
+      mask = [mask] * self.config.num_layers
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+
+def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=9216,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_base=10000,
+        rotary_percentage=1.0,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
+  embedding_dim = 2304
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=num_layers,
+      max_seq_len=8192,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+      final_logit_softcap=30.0,
+  )
+  return config
+
+
+def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder2_config(kv_cache_max_len)
+  # PaliGemma2 decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
+  return config
+
+
+def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder2_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Decoder2,
+  )
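One detail worth noting in `get_decoder2_config`: the attention type alternates with layer parity, so the 26 layers interleave global and sliding-window attention. A quick check of the pattern, mirroring the conditional in `get_block_config`:

```python
pattern = ["GLOBAL" if idx % 2 == 0 else "LOCAL_SLIDING" for idx in range(4)]
print(pattern)  # ['GLOBAL', 'LOCAL_SLIDING', 'GLOBAL', 'LOCAL_SLIDING']
```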
ai_edge_torch/generative/examples/paligemma/paligemma.py

@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Optional

 from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
@@ -38,13 +39,14 @@ class PaliGemmaConfig:
   decoder_config: cfg.ModelConfig

   image_token_id: int
+  image_projection_scale: float
   image_projection_use_bias: bool = False


 class PaliGemma(nn.Module):
   """PaliGemma model from the Edge Generative API."""

-  def __init__(self, config: PaliGemmaConfig):
+  def __init__(self, config: PaliGemmaConfig, decoder_class: nn.Module):
     super().__init__()

     self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -55,7 +57,7 @@ class PaliGemma(nn.Module):
         config.decoder_config.embedding_dim,
         bias=config.image_projection_use_bias,
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder_class(config.decoder_config)
     image_embedding_config = config.image_encoder_config.image_embedding
     self.num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
@@ -70,6 +72,7 @@ class PaliGemma(nn.Module):
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
       return self.decoder(
@@ -77,15 +80,15 @@ class PaliGemma(nn.Module):
           input_pos=input_pos,
           kv_cache=kv_cache,
           input_embeds=None,
-          export_config=export_config
+          export_config=export_config,
+          called_by_generate=called_by_generate,
       )

     input_embeds = self.decoder.tok_embedding(tokens)

     image_encoded = self.image_encoder(pixel_values=pixel_values)
     image_embeds = self.image_projection(image_encoded)
-
-    image_embeds = image_embeds / self.config.decoder_config.embedding_scale
+    image_embeds = image_embeds / self.config.image_projection_scale

     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -110,10 +113,11 @@ class PaliGemma(nn.Module):
         kv_cache=kv_cache,
         input_embeds=input_embeds,
         export_config=export_config,
+        called_by_generate=called_by_generate,
     )


-def get_model_config(**kwargs) -> PaliGemmaConfig:
+def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.

   Returns:
@@ -121,31 +125,42 @@ def get_model_config(**kwargs) -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(**kwargs),
-      image_projection_use_bias=True,
+      decoder_config=get_decoder_config(**kwargs),
       image_token_id=257152,
+      image_projection_scale=2048**0.5,
+      image_projection_use_bias=True,
   )


-def get_fake_model_config() -> PaliGemmaConfig:
+def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config(),
+      decoder_config=get_decoder_config(**kwargs),
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
-      image_token_id=257152,
   )


-def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
-  config = get_model_config(**kwargs)
-  model = PaliGemma(config)
+def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+  if version == 1:
+    decoder_class = decoder.Decoder
+    decoder_tensor_names = decoder.TENSOR_NAMES
+    get_decoder_config = decoder.get_decoder_config
+  else:
+    decoder_class = decoder2.Decoder2
+    decoder_tensor_names = decoder2.TENSOR_NAMES
+    get_decoder_config = decoder2.get_decoder2_config
+
+  config = get_model_config(get_decoder_config, **kwargs)
+  model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
       checkpoint_path, image_encoder.TENSOR_NAMES
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
   loader.load(model.decoder, strict=False)

   # Load the parameters of image projection.
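The new `image_projection_scale` field decouples image-embedding normalization from the decoder's text-embedding scale: with the Gemma2-based decoder the text scale is 2304**0.5, while the image projection is still divided by 2048**0.5, so a single config value can no longer serve both. The numbers:

```python
print(2304 ** 0.5)  # 48.0         -- decoder2 embedding_scale
print(2048 ** 0.5)  # 45.2548...   -- image_projection_scale
```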
ai_edge_torch/generative/examples/paligemma/verify.py

@@ -22,11 +22,18 @@ from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import verifier
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers

+_VERSION = flags.DEFINE_enum(
+    "version",
+    "2",
+    ["1", "2"],
+    "The version of PaliGemma model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
@@ -34,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
 )
 _PROMPTS = flags.DEFINE_string(
     "prompts",
-    "Caption en",
+    "describe en",
     "The input prompts to generate answers.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -43,28 +50,47 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )

+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+

 class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored PaliGemma model wrapper."""

+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.forward_called_by_generate = False
+
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)

+  def _get_extra_args_for_forward(self):
+    return {"called_by_generate": self.forward_called_by_generate}
+

 def main(_):
-  checkpoint = "google/paligemma-3b-mix-224"
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = (
       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )

-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = paligemma.build_model(reauthored_checkpoint)
+  reauthored_model = paligemma.build_model(
+      reauthored_checkpoint, version=int(_VERSION.value)
+  )

   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
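The wrapper additions assume `verifier.ReauthoredModelWrapper` consults a `_get_extra_args_for_forward` hook when it calls the model (the hook's plumbing lives in the verifier.py changes listed above but not shown here). A hedged sketch of how the toggle is driven, assuming `reauthored_model` was built via `paligemma.build_model`:

```python
wrapper = ReauthoredPaliGemmaWrapper(reauthored_model)

# forward()-based verification exercises the prefix-LM mask path.
wrapper.forward_called_by_generate = False

# generate()-based verification switches Decoder2 to the causal-mask path,
# as main() does right before calling wrapper.generate().
wrapper.forward_called_by_generate = True
```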
@@ -93,7 +119,7 @@ def main(_):
   logging.info("outputs_reauthored: %s", outputs_reauthored)

   try:
-    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-02)
   except AssertionError as e:
     logging.error("*** FAILED *** verify with forward()")
     raise e
@@ -111,6 +137,7 @@ def main(_):
   logging.info("outputs_from_original_model: [[%s]]", response_original)

   logging.info("Generating answer with the reauthored model...")
+  wrapped_reauthored_model.forward_called_by_generate = True
   outputs_reauthored = wrapped_reauthored_model.generate(
       prompts=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],