ai-edge-torch-nightly 0.3.0.dev20241218__py3-none-any.whl → 0.3.0.dev20241224__py3-none-any.whl

Files changed (26)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +3 -2
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +43 -25
  3. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
  4. ai_edge_torch/generative/examples/paligemma/decoder.py +14 -5
  5. ai_edge_torch/generative/examples/paligemma/decoder2.py +174 -0
  6. ai_edge_torch/generative/examples/paligemma/paligemma.py +30 -15
  7. ai_edge_torch/generative/examples/paligemma/verify.py +36 -9
  8. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  9. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +24 -7
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  11. ai_edge_torch/generative/layers/attention.py +4 -29
  12. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
  13. ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
  14. ai_edge_torch/generative/utilities/model_builder.py +14 -14
  15. ai_edge_torch/generative/utilities/verifier.py +22 -22
  16. ai_edge_torch/odml_torch/export.py +6 -1
  17. ai_edge_torch/odml_torch/jax_bridge/__init__.py +4 -1
  18. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  19. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -2
  20. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  21. ai_edge_torch/version.py +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA +1 -1
  23. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD +26 -23
  24. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/LICENSE +0 -0
  25. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/WHEEL +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma1.py
@@ -72,12 +72,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,

ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,48 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [self.get_attention_mask(
+        self.config.block_config(i).attn_config.attn_type, input_pos
+    ) for i in range(self.config.num_layers)]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
    assert len(self.transformer_blocks) == len(kv_cache.caches), (
        "The number of transformer blocks and the number of KV cache entries"
        " must be the same."
    )
 
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +243,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +266,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
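
Note: the hunks above replace the precomputed self.rope_cache with RoPE built on each forward call. A minimal sketch of that pattern, assuming only the build_rope signature shown in the diff (rope_for is a hypothetical helper, not part of the package):

    import torch
    import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

    def rope_for(config, input_pos: torch.Tensor):
      # All Gemma2 blocks share RoPE hyperparameters; read them off block 0.
      attn_config = config.block_config(0).attn_config
      n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
      # build_rope returns the (cos, sin) pair consumed by every transformer block.
      return rotary_pos_emb.build_rope(
          input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
      )

    # e.g. during prefill: rope = rope_for(model.config, torch.arange(0, seq_len))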

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+_VERSION = flags.DEFINE_enum(
+    'version',
+    '2',
+    ['1', '2'],
+    'The version of PaliGemma model to verify.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(
 
 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      _CHECKPOINT_PATH.value,
+      version=int(_VERSION.value),
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
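
Note: a hedged sketch of the model-building step after this change; the checkpoint path is a placeholder and only the version and kv_cache_max_len arguments are confirmed by the diff above:

    from ai_edge_torch.generative.examples.paligemma import paligemma

    # version=2 selects the Gemma2-based decoder (decoder2.Decoder2);
    # version=1 keeps the original Gemma-based decoder.
    pytorch_model = paligemma.build_model(
        '/path/to/paligemma2-3b-224',  # placeholder for _CHECKPOINT_PATH
        version=2,                     # new 'version' flag, '1' or '2'
        kv_cache_max_len=1024,
    )
    # The exported file name now carries the version, e.g.
    # paligemma2_f32_seq<prefill>_ekv1024.tflite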

ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -19,6 +19,7 @@ from typing import Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -54,6 +55,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -61,8 +63,12 @@ class Decoder(model_builder.DecoderOnlyModel):
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
@@ -70,7 +76,7 @@ class Decoder(model_builder.DecoderOnlyModel):
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
 
-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
     )
 
@@ -108,12 +114,13 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=257216,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -130,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config
 
 
ai_edge_torch/generative/examples/paligemma/decoder2.py (new file)
@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma2 3B model which is Gemma2."""
+
+from typing import Optional
+
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+class Decoder2(gemma2.Gemma2):
+  """A decoder of PaliGemma2 3B model which is Gemma2.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      return super().forward(tokens, input_pos, kv_cache)
+
+    assert input_embeds is not None
+
+    repo_pos = input_pos + 1  # PaliGemma2 position is 1-based.
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
+    if called_by_generate:
+      # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+      mask = [self.get_attention_mask(
+          self.config.block_config(i).attn_config.attn_type, input_pos
+      ) for i in range(self.config.num_layers)]
+    else:
+      # By default, don't mask image embeds with a diagonal causal mask.
+      embeds_len = input_embeds.shape[1]
+      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask[:, embeds_len:] = float("-inf")
+      mask = [mask] * self.config.num_layers
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+
+def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=9216,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_base=10000,
+        rotary_percentage=1.0,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
+  embedding_dim = 2304
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=num_layers,
+      max_seq_len=8192,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+      final_logit_softcap=30.0,
+  )
+  return config
+
+
+def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder2_config(kv_cache_max_len)
+  # PaliGemma2 decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
+  return config
+
+
+def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder2_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Decoder2,
+  )
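
Note: a minimal sketch (assumed usage, not from the diff) that exercises the new Decoder2 text-only path with the fake config, so no checkpoint download is needed:

    import torch
    from ai_edge_torch.generative.examples.paligemma import decoder2
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    config = decoder2.get_fake_decoder2_config()
    model = decoder2.Decoder2(config)
    tokens = torch.zeros((1, 4), dtype=torch.int)
    input_pos = torch.arange(0, 4, dtype=torch.int)
    kv = kv_utils.KVCache.from_model_config(config)
    # With input_embeds left as None, forward() falls through to Gemma2.forward().
    output = model(tokens, input_pos, kv)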

ai_edge_torch/generative/examples/paligemma/paligemma.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
@@ -38,13 +39,14 @@ class PaliGemmaConfig:
   decoder_config: cfg.ModelConfig
 
   image_token_id: int
+  image_projection_scale: float
   image_projection_use_bias: bool = False
 
 
 class PaliGemma(nn.Module):
   """PaliGemma model from the Edge Generative API."""
 
-  def __init__(self, config: PaliGemmaConfig):
+  def __init__(self, config: PaliGemmaConfig, decoder_class: nn.Module):
     super().__init__()
 
     self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -55,7 +57,7 @@ class PaliGemma(nn.Module):
         config.decoder_config.embedding_dim,
         bias=config.image_projection_use_bias,
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder_class(config.decoder_config)
     image_embedding_config = config.image_encoder_config.image_embedding
     self.num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
@@ -70,6 +72,7 @@ class PaliGemma(nn.Module):
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
       return self.decoder(
@@ -77,15 +80,15 @@ class PaliGemma(nn.Module):
          input_pos=input_pos,
          kv_cache=kv_cache,
          input_embeds=None,
-          export_config=export_config
+          export_config=export_config,
+          called_by_generate=called_by_generate,
      )
 
     input_embeds = self.decoder.tok_embedding(tokens)
 
     image_encoded = self.image_encoder(pixel_values=pixel_values)
     image_embeds = self.image_projection(image_encoded)
-    if self.config.decoder_config.embedding_scale is not None:
-      image_embeds = image_embeds / self.config.decoder_config.embedding_scale
+    image_embeds = image_embeds / self.config.image_projection_scale
 
     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -110,10 +113,11 @@ class PaliGemma(nn.Module):
         kv_cache=kv_cache,
         input_embeds=input_embeds,
         export_config=export_config,
+        called_by_generate=called_by_generate,
     )
 
 
-def get_model_config(**kwargs) -> PaliGemmaConfig:
+def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
@@ -121,31 +125,42 @@ def get_model_config(**kwargs) -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(**kwargs),
-      image_projection_use_bias=True,
+      decoder_config=get_decoder_config(**kwargs),
       image_token_id=257152,
+      image_projection_scale=2048**0.5,
+      image_projection_use_bias=True,
   )
 
 
-def get_fake_model_config() -> PaliGemmaConfig:
+def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config(),
+      decoder_config=get_decoder_config(**kwargs),
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
-      image_token_id=257152,
   )
 
 
-def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
-  config = get_model_config(**kwargs)
-  model = PaliGemma(config)
+def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+  if version == 1:
+    decoder_class = decoder.Decoder
+    decoder_tensor_names = decoder.TENSOR_NAMES
+    get_decoder_config = decoder.get_decoder_config
+  else:
+    decoder_class = decoder2.Decoder2
+    decoder_tensor_names = decoder2.TENSOR_NAMES
+    get_decoder_config = decoder2.get_decoder2_config
+
+  config = get_model_config(get_decoder_config, **kwargs)
+  model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
       checkpoint_path, image_encoder.TENSOR_NAMES
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
   loader.load(model.decoder, strict=False)
 
   # Load the parameters of image projection.
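
Note: with the refactor above, get_model_config takes the decoder-config factory, and build_model reduces version selection to choosing the factory, decoder class, and tensor names. A short sketch of the config plumbing (the commented checkpoint path is a placeholder):

    from ai_edge_torch.generative.examples.paligemma import decoder, decoder2, paligemma

    # Version 1: original Gemma-based decoder; version 2: Gemma2-based decoder2.
    config_v1 = paligemma.get_model_config(decoder.get_decoder_config)
    config_v2 = paligemma.get_model_config(decoder2.get_decoder2_config)

    # build_model() wires the matching decoder class and TENSOR_NAMES internally:
    # model = paligemma.build_model('/path/to/checkpoint', version=2)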

ai_edge_torch/generative/examples/paligemma/verify.py
@@ -22,11 +22,18 @@ from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import verifier
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers
 
+_VERSION = flags.DEFINE_enum(
+    "version",
+    "2",
+    ["1", "2"],
+    "The version of PaliGemma model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
@@ -34,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
 )
 _PROMPTS = flags.DEFINE_string(
     "prompts",
-    "Caption en",
+    "describe en",
     "The input prompts to generate answers.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -43,28 +50,47 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+
 
 class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored PaliGemma model wrapper."""
 
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.forward_called_by_generate = False
+
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
+  def _get_extra_args_for_forward(self):
+    return {"called_by_generate": self.forward_called_by_generate}
+
 
 def main(_):
-  checkpoint = "google/paligemma-3b-mix-224"
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = (
      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )
 
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = paligemma.build_model(reauthored_checkpoint)
+  reauthored_model = paligemma.build_model(
+      reauthored_checkpoint, version=int(_VERSION.value)
+  )
 
   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
@@ -93,7 +119,7 @@ def main(_):
   logging.info("outputs_reauthored: %s", outputs_reauthored)
 
   try:
-    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-02)
   except AssertionError as e:
     logging.error("*** FAILED *** verify with forward()")
     raise e
@@ -111,6 +137,7 @@ def main(_):
   logging.info("outputs_from_original_model: [[%s]]", response_original)
 
   logging.info("Generating answer with the reauthored model...")
+  wrapped_reauthored_model.forward_called_by_generate = True
   outputs_reauthored = wrapped_reauthored_model.generate(
       prompts=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],