ai-edge-torch-nightly 0.3.0.dev20241218__py3-none-any.whl → 0.3.0.dev20241224__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (26)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +3 -2
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +43 -25
  3. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
  4. ai_edge_torch/generative/examples/paligemma/decoder.py +14 -5
  5. ai_edge_torch/generative/examples/paligemma/decoder2.py +174 -0
  6. ai_edge_torch/generative/examples/paligemma/paligemma.py +30 -15
  7. ai_edge_torch/generative/examples/paligemma/verify.py +36 -9
  8. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  9. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +24 -7
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  11. ai_edge_torch/generative/layers/attention.py +4 -29
  12. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
  13. ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
  14. ai_edge_torch/generative/utilities/model_builder.py +14 -14
  15. ai_edge_torch/generative/utilities/verifier.py +22 -22
  16. ai_edge_torch/odml_torch/export.py +6 -1
  17. ai_edge_torch/odml_torch/jax_bridge/__init__.py +4 -1
  18. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  19. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -2
  20. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  21. ai_edge_torch/version.py +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA +1 -1
  23. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD +26 -23
  24. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/LICENSE +0 -0
  25. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/WHEEL +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/top_level.txt +0 -0
@@ -72,12 +72,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  pre_attention_norm_config=norm_config,
  post_attention_norm_config=norm_config,
  )
+ embedding_dim = 2048
  config = cfg.ModelConfig(
  vocab_size=256000,
  num_layers=18,
  max_seq_len=8192,
- embedding_dim=2048,
- embedding_scale=2048**0.5,
+ embedding_dim=embedding_dim,
+ embedding_scale=embedding_dim**0.5,
  kv_cache_max_len=kv_cache_max_len,
  block_configs=block_config,
  final_norm_config=norm_config,
@@ -15,13 +15,14 @@

  """Example of building a Gemma2 model."""

- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
  config.embedding_dim,
  config.final_norm_config,
  )
- # Gemma2 has same hyper parameters for each layer except for attention
- # types. Use the first layer.
- attn_config = config.block_config(0).attn_config
- self.rope_cache = attn_utils.build_rope_cache(
- size=config.kv_cache_max,
- dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=attn_config.rotary_base,
- )
  self.mask_cache = attn_utils.build_causal_mask_cache(
  size=config.kv_cache_max,
  )
+ # Gemma2 has same hyper parameters for each layer except for attention
+ # types. Use the first layer.
+ attn_config = config.block_config(0).attn_config
  self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
  size=config.kv_cache_max,
  window_size=attn_config.sliding_window_size,
@@ -140,29 +136,48 @@ class Gemma2(nn.Module):
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
+
+ # token embeddings of shape (b, t, n_embd)
+ input_embeds = self.tok_embedding(tokens)
+ # RoPE parameters are the same for all blocks. Use the first layer.
+ attn_config = self.config.block_config(0).attn_config
+ n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+ rope = rotary_pos_emb.build_rope(
+ input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+ )
+ mask = [self.get_attention_mask(
+ self.config.block_config(i).attn_config.attn_type, input_pos
+ ) for i in range(self.config.num_layers)]
+
+ return self._forward_with_embeds(
+ input_embeds, rope, mask, input_pos, kv_cache, export_config
+ )
+
+ def _forward_with_embeds(
+ self,
+ input_embeds: torch.Tensor,
+ rope: Tuple[torch.Tensor, torch.Tensor],
+ mask: List[torch.Tensor],
+ input_pos: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ export_config: Optional[model_builder.ExportConfig] = None,
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
+ """Forwards the model with input embeddings."""
  assert len(self.transformer_blocks) == len(kv_cache.caches), (
  "The number of transformer blocks and the number of KV cache entries"
  " must be the same."
  )

- cos, sin = self.rope_cache
- cos = cos.index_select(0, input_pos)
- sin = sin.index_select(0, input_pos)
-
- # token embeddings of shape (b, t, n_embd)
- x = self.tok_embedding(tokens)
- x = x * (self.config.embedding_dim**0.5)
-
- updated_kv_entires = []
+ if self.config.embedding_scale is not None:
+ input_embeds = input_embeds * self.config.embedding_scale
+ x = input_embeds
+ updated_kv_entries = []
  for i, block in enumerate(self.transformer_blocks):
- mask = self.get_attention_mask(
- block.config.attn_config.attn_type, input_pos
- )
  kv_entry = kv_cache.caches[i] if kv_cache else None
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+ x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
  if kv_entry:
- updated_kv_entires.append(kv_entry)
- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+ updated_kv_entries.append(kv_entry)
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

  if export_config is not None:
  if (
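A minimal sketch (not part of the diff) of how a caller can drive the new _forward_with_embeds helper with precomputed embeddings, mirroring the calls in the hunk above; the function and tensor names outside the hunk are illustrative:

    import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

    def forward_with_precomputed_embeds(model, input_embeds, input_pos, kv_cache):
      # RoPE parameters are shared across blocks; read them from the first block.
      attn_config = model.config.block_config(0).attn_config
      n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
      rope = rotary_pos_emb.build_rope(
          input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
      )
      # One mask per transformer block, since attention types alternate per layer.
      mask = [
          model.get_attention_mask(
              model.config.block_config(i).attn_config.attn_type, input_pos
          )
          for i in range(model.config.num_layers)
      ]
      return model._forward_with_embeds(input_embeds, rope, mask, input_pos, kv_cache)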
@@ -228,11 +243,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  )

  num_layers = 26
+ embedding_dim = 2304
  config = cfg.ModelConfig(
  vocab_size=256000,
  num_layers=num_layers,
  max_seq_len=8192,
- embedding_dim=2304,
+ embedding_dim=embedding_dim,
+ embedding_scale=embedding_dim**0.5,
  kv_cache_max_len=kv_cache_max_len,
  block_configs=[get_block_config(i) for i in range(num_layers)],
  final_norm_config=norm_config,
@@ -249,6 +266,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  config.embedding_dim = 128
+ config.embedding_scale = config.embedding_dim**0.5
  config.block_configs = config.block_configs[: config.num_layers]
  for block_config in config.block_configs:
  block_config.attn_config.num_heads = 4
@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
  import torch

+ _VERSION = flags.DEFINE_enum(
+ 'version',
+ '2',
+ ['1', '2'],
+ 'The version of PaliGemma model to verify.',
+ )
  _CHECKPOINT_PATH = flags.DEFINE_string(
  'checkpoint_path',
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
  'The path to the model checkpoint, or directory holding the checkpoint.',
  )
  _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(

  def main(_):
  pytorch_model = paligemma.build_model(
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+ _CHECKPOINT_PATH.value,
+ version=int(_VERSION.value),
+ kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
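A small sketch (not part of the diff) of the new output-name pattern, which now embeds the PaliGemma version; the values below are placeholders for what the script reads from its flags:

    # Placeholder values; the real script takes these from its flags.
    version, quant_suffix, prefill_seq_len, kv_cache_max_len = "2", "q8", 1024, 1280
    output_filename = (
        f"paligemma{version}_{quant_suffix}_seq{prefill_seq_len}"
        f"_ekv{kv_cache_max_len}.tflite"
    )
    print(output_filename)  # paligemma2_q8_seq1024_ekv1280.tflite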
@@ -19,6 +19,7 @@ from typing import Optional

  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
@@ -54,6 +55,7 @@ class Decoder(model_builder.DecoderOnlyModel):
  kv_cache: kv_utils.KVCache,
  input_embeds: torch.Tensor = None,
  export_config: Optional[model_builder.ExportConfig] = None,
+ called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if input_embeds is None:
  return super().forward(tokens, input_pos, kv_cache)
@@ -61,8 +63,12 @@ class Decoder(model_builder.DecoderOnlyModel):
  assert input_embeds is not None

  repo_pos = input_pos + 1 # PaliGemma position is 1-based.
- cos, sin = self.rope_cache
- rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+ # ROPE parameters for all attn_configs are the same. Take the first one.
+ attn_config = self.config.block_config(0).attn_config
+ n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+ rope = rotary_pos_emb.build_rope(
+ repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+ )

  # The first part of input_embeds are image embeddings. Diagonal causal mask
  # doesn't work here.
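For context, a hedged sketch of what a RoPE builder like build_rope typically computes for the given positions: inverse frequencies per rotated dimension pair, then cosines and sines of position times frequency. The package's actual build_rope takes (positions, n_elem, head_dim, base) and may differ in output layout or dtype; this only illustrates the standard construction:

    import torch

    def build_rope_sketch(positions: torch.Tensor, n_elem: int, base: int = 10000):
      # Inverse frequencies for each pair of rotated dimensions.
      inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
      angles = positions.float().unsqueeze(-1) * inv_freq  # (num_pos, n_elem // 2)
      return torch.cos(angles), torch.sin(angles)

    # PaliGemma positions are 1-based, hence the +1 seen in the hunk above.
    cos, sin = build_rope_sketch(torch.arange(4) + 1, n_elem=256)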
@@ -70,7 +76,7 @@ class Decoder(model_builder.DecoderOnlyModel):
  mask = torch.zeros(embeds_len, self.config.kv_cache_max)
  mask[:, embeds_len:] = float("-inf")

- return self.forward_with_embeds(
+ return self._forward_with_embeds(
  input_embeds, rope, mask, input_pos, kv_cache
  )

@@ -108,12 +114,13 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  pre_attention_norm_config=norm_config,
  post_attention_norm_config=norm_config,
  )
+ embedding_dim = 2048
  config = cfg.ModelConfig(
  vocab_size=257216,
  num_layers=18,
  max_seq_len=8192,
- embedding_dim=2048,
- embedding_scale=2048**0.5,
+ embedding_dim=embedding_dim,
+ embedding_scale=embedding_dim**0.5,
  kv_cache_max_len=kv_cache_max_len,
  block_configs=block_config,
  final_norm_config=norm_config,
@@ -130,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config.vocab_size = 128
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
+ config.embedding_dim = 128
+ config.embedding_scale = 128**0.5
  return config


@@ -0,0 +1,174 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building a decoder of PaliGemma2 3B model which is Gemma2."""
+
+ from typing import Optional
+
+ from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+ from ai_edge_torch.generative.utilities import model_builder
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+ ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+ ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+ attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+ attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+ attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+ attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+ pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+ post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+ pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
+ post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
+ embedding="language_model.model.embed_tokens",
+ final_norm="language_model.model.norm",
+ lm_head=None,
+ )
+
+
+ class Decoder2(gemma2.Gemma2):
+ """A decoder of PaliGemma2 3B model which is Gemma2.
+
+ Besides a tensor of text token IDs, forward() can also take a tensor of
+ embeddings which may include text or image or both.
+ """
+
+ @torch.inference_mode
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ input_pos: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ input_embeds: torch.Tensor = None,
+ export_config: Optional[model_builder.ExportConfig] = None,
+ called_by_generate: bool = True,
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
+ if input_embeds is None:
+ return super().forward(tokens, input_pos, kv_cache)
+
+ assert input_embeds is not None
+
+ repo_pos = input_pos + 1 # PaliGemma2 position is 1-based.
+ # ROPE parameters for all attn_configs are the same. Take the first one.
+ attn_config = self.config.block_config(0).attn_config
+ n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+ rope = rotary_pos_emb.build_rope(
+ repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+ )
+
+ if called_by_generate:
+ # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+ mask = [self.get_attention_mask(
+ self.config.block_config(i).attn_config.attn_type, input_pos
+ ) for i in range(self.config.num_layers)]
+ else:
+ # By default, don't mask image embeds with a diagonal causal mask.
+ embeds_len = input_embeds.shape[1]
+ mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+ mask[:, embeds_len:] = float("-inf")
+ mask = [mask] * self.config.num_layers
+
+ return self._forward_with_embeds(
+ input_embeds, rope, mask, input_pos, kv_cache, export_config
+ )
+
+
+ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for the decoder of a PaliGemma 3B model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for the decoder of a PaliGemma 3B model.
+ """
+ norm_config = cfg.NormalizationConfig(
+ type=cfg.NormalizationType.RMS_NORM,
+ epsilon=1e-6,
+ zero_centered=True,
+ )
+ ff_config = cfg.FeedForwardConfig(
+ type=cfg.FeedForwardType.GATED,
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+ intermediate_size=9216,
+ pre_ff_norm_config=norm_config,
+ post_ff_norm_config=norm_config,
+ )
+
+ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+ attn_config = cfg.AttentionConfig(
+ num_heads=8,
+ head_dim=256,
+ num_query_groups=4,
+ rotary_base=10000,
+ rotary_percentage=1.0,
+ logit_softcap=50.0,
+ sliding_window_size=4096,
+ attn_type=(
+ cfg.AttentionType.GLOBAL
+ if idx % 2 == 0
+ else cfg.AttentionType.LOCAL_SLIDING
+ ),
+ )
+ return cfg.TransformerBlockConfig(
+ attn_config=attn_config,
+ ff_config=ff_config,
+ pre_attention_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
+ )
+
+ num_layers = 26
+ embedding_dim = 2304
+ config = cfg.ModelConfig(
+ vocab_size=257216,
+ num_layers=num_layers,
+ max_seq_len=8192,
+ embedding_dim=embedding_dim,
+ embedding_scale=embedding_dim**0.5,
+ kv_cache_max_len=kv_cache_max_len,
+ block_configs=[get_block_config(i) for i in range(num_layers)],
+ final_norm_config=norm_config,
+ lm_head_use_bias=False,
+ enable_hlfb=True,
+ final_logit_softcap=30.0,
+ )
+ return config
+
+
+ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+ config = get_decoder2_config(kv_cache_max_len)
+ # PaliGemma2 decoder has only one block config.
+ config.block_config(0).ff_config.intermediate_size = 128
+ config.vocab_size = 128
+ config.num_layers = 2
+ config.max_seq_len = 2 * kv_cache_max_len
+ config.embedding_dim = 128
+ config.embedding_scale = 128**0.5
+ return config
+
+
+ def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+ return model_builder.build_decoder_only_model(
+ checkpoint_path=checkpoint_path,
+ config=get_decoder2_config(**kwargs),
+ tensor_names=TENSOR_NAMES,
+ model_class=Decoder2,
+ )
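A usage sketch (not part of the diff) for the new standalone PaliGemma2 decoder; the checkpoint path is a placeholder and must hold tensors named per TENSOR_NAMES above:

    from ai_edge_torch.generative.examples.paligemma import decoder2

    # Builds the Gemma2-based decoder alone, without the image encoder.
    model = decoder2.build_decoder2(
        "/path/to/paligemma2-3b-224", kv_cache_max_len=1024
    )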
@@ -19,6 +19,7 @@ from dataclasses import dataclass
  from typing import Optional

  from ai_edge_torch.generative.examples.paligemma import decoder
+ from ai_edge_torch.generative.examples.paligemma import decoder2
  from ai_edge_torch.generative.examples.paligemma import image_encoder
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
@@ -38,13 +39,14 @@ class PaliGemmaConfig:
  decoder_config: cfg.ModelConfig

  image_token_id: int
+ image_projection_scale: float
  image_projection_use_bias: bool = False


  class PaliGemma(nn.Module):
  """PaliGemma model from the Edge Generative API."""

- def __init__(self, config: PaliGemmaConfig):
+ def __init__(self, config: PaliGemmaConfig, decoder_class: nn.Module):
  super().__init__()

  self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -55,7 +57,7 @@ class PaliGemma(nn.Module):
  config.decoder_config.embedding_dim,
  bias=config.image_projection_use_bias,
  )
- self.decoder = decoder.Decoder(config.decoder_config)
+ self.decoder = decoder_class(config.decoder_config)
  image_embedding_config = config.image_encoder_config.image_embedding
  self.num_patches = (
  image_embedding_config.image_size // image_embedding_config.patch_size
@@ -70,6 +72,7 @@ class PaliGemma(nn.Module):
  kv_cache: kv_utils.KVCache,
  pixel_values: torch.Tensor = None,
  export_config: Optional[model_builder.ExportConfig] = None,
+ called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if pixel_values is None:
  return self.decoder(
@@ -77,15 +80,15 @@ class PaliGemma(nn.Module):
  input_pos=input_pos,
  kv_cache=kv_cache,
  input_embeds=None,
- export_config=export_config
+ export_config=export_config,
+ called_by_generate=called_by_generate,
  )

  input_embeds = self.decoder.tok_embedding(tokens)

  image_encoded = self.image_encoder(pixel_values=pixel_values)
  image_embeds = self.image_projection(image_encoded)
- if self.config.decoder_config.embedding_scale is not None:
- image_embeds = image_embeds / self.config.decoder_config.embedding_scale
+ image_embeds = image_embeds / self.config.image_projection_scale

  # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
  # can be done like:
@@ -110,10 +113,11 @@ class PaliGemma(nn.Module):
  kv_cache=kv_cache,
  input_embeds=input_embeds,
  export_config=export_config,
+ called_by_generate=called_by_generate,
  )


- def get_model_config(**kwargs) -> PaliGemmaConfig:
+ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  """Returns the model config for a PaliGemma 3B-224 model.

  Returns:
@@ -121,31 +125,42 @@ def get_model_config(**kwargs) -> PaliGemmaConfig:
  """
  return PaliGemmaConfig(
  image_encoder_config=image_encoder.get_image_encoder_config(),
- decoder_config=decoder.get_decoder_config(**kwargs),
- image_projection_use_bias=True,
+ decoder_config=get_decoder_config(**kwargs),
  image_token_id=257152,
+ image_projection_scale=2048**0.5,
+ image_projection_use_bias=True,
  )


- def get_fake_model_config() -> PaliGemmaConfig:
+ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  return PaliGemmaConfig(
  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
- decoder_config=decoder.get_fake_decoder_config(),
+ decoder_config=get_decoder_config(**kwargs),
+ image_token_id=127,
+ image_projection_scale=128**0.5,
  image_projection_use_bias=True,
- image_token_id=257152,
  )


- def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
- config = get_model_config(**kwargs)
- model = PaliGemma(config)
+ def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+ if version == 1:
+ decoder_class = decoder.Decoder
+ decoder_tensor_names = decoder.TENSOR_NAMES
+ get_decoder_config = decoder.get_decoder_config
+ else:
+ decoder_class = decoder2.Decoder2
+ decoder_tensor_names = decoder2.TENSOR_NAMES
+ get_decoder_config = decoder2.get_decoder2_config
+
+ config = get_model_config(get_decoder_config, **kwargs)
+ model = PaliGemma(config, decoder_class)
  # Load the parameters of image encoder.
  loader = loading_utils.ModelLoader(
  checkpoint_path, image_encoder.TENSOR_NAMES
  )
  loader.load(model.image_encoder, strict=False)
  # Load the parameters of decoder.
- loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+ loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
  loader.load(model.decoder, strict=False)

  # Load the parameters of image projection.
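A usage sketch (not part of the diff) of the version selection added to build_model; checkpoint paths are placeholders:

    from ai_edge_torch.generative.examples.paligemma import paligemma

    # version=2 (the default) pairs the model with the Gemma2-based Decoder2;
    # version=1 keeps the original Gemma1-based Decoder and its tensor names.
    pg2 = paligemma.build_model("/path/to/paligemma2-3b-224")
    pg1 = paligemma.build_model("/path/to/paligemma-3b-224", version=1)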
@@ -22,11 +22,18 @@ from absl import flags
  from ai_edge_torch.generative.examples.paligemma import paligemma
  from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.utilities import verifier
+ import kagglehub
  from PIL import Image
  import requests
  import torch
  import transformers

+ _VERSION = flags.DEFINE_enum(
+ "version",
+ "2",
+ ["1", "2"],
+ "The version of PaliGemma model to verify.",
+ )
  _IMAGE_URL = flags.DEFINE_string(
  "image_url",
  "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
@@ -34,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
  )
  _PROMPTS = flags.DEFINE_string(
  "prompts",
- "Caption en",
+ "describe en",
  "The input prompts to generate answers.",
  )
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -43,28 +50,47 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  "The maximum size of the generated tokens.",
  )

+ _CHECKPOINT = {
+ "1": "google/paligemma-3b-mix-224",
+ "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+ }
+

  class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
  """Reauthored PaliGemma model wrapper."""

+ def __init__(self, model: torch.nn.Module):
+ super().__init__(model)
+ self.forward_called_by_generate = False
+
  def _init_kv_cache(self):
  return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)

+ def _get_extra_args_for_forward(self):
+ return {"called_by_generate": self.forward_called_by_generate}
+

  def main(_):
- checkpoint = "google/paligemma-3b-mix-224"
+ if _VERSION.value == "1":
+ checkpoint = _CHECKPOINT[_VERSION.value]
+ # Locate the cached dir.
+ cached_config_file = transformers.utils.cached_file(
+ checkpoint, transformers.utils.CONFIG_NAME
+ )
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+ else:
+ checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+ reauthored_checkpoint = checkpoint
+
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = (
  transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
  )

- # Locate the cached dir.
- cached_config_file = transformers.utils.cached_file(
- checkpoint, transformers.utils.CONFIG_NAME
- )
- reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
- reauthored_model = paligemma.build_model(reauthored_checkpoint)
+ reauthored_model = paligemma.build_model(
+ reauthored_checkpoint, version=int(_VERSION.value)
+ )

  logging.info("Loading the processor from: %s", checkpoint)
  # It works only when GemmaTokenizerFast is available. In some environments,
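For the version-2 path, the checkpoint is now resolved through kagglehub rather than the Hugging Face cache. A minimal sketch (not part of the diff); kagglehub may require Kaggle credentials and license acceptance for this model:

    import kagglehub

    # Downloads (or reuses) the PaliGemma2 checkpoint and returns its local
    # directory, which main() then passes to paligemma.build_model().
    path = kagglehub.model_download(
        "google/paligemma-2/transformers/paligemma2-3b-pt-224"
    )
    print(path)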
@@ -93,7 +119,7 @@ def main(_):
  logging.info("outputs_reauthored: %s", outputs_reauthored)

  try:
- assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
+ assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-02)
  except AssertionError as e:
  logging.error("*** FAILED *** verify with forward()")
  raise e
@@ -111,6 +137,7 @@ def main(_):
  logging.info("outputs_from_original_model: [[%s]]", response_original)

  logging.info("Generating answer with the reauthored model...")
+ wrapped_reauthored_model.forward_called_by_generate = True
  outputs_reauthored = wrapped_reauthored_model.generate(
  prompts=inputs["input_ids"],
  pixel_values=inputs["pixel_values"],