ai-edge-torch-nightly 0.3.0.dev20241108__py3-none-any.whl → 0.3.0.dev20241109__py3-none-any.whl

ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -93,6 +93,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
+      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/paligemma/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/paligemma/decoder.py ADDED
@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building the decoder of the PaliGemma 3B model, which is Gemma1."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      embedding_scale=2048**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder_config(kv_cache_max_len)
+  # PaliGemma decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  return config
+
+
+def build_decoder(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
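
For orientation, a minimal usage sketch of the new module (not part of the diff itself): the checkpoint path below is a placeholder, and keyword arguments are forwarded to get_decoder_config(), so kv_cache_max_len can be overridden at build time.

    # Hypothetical usage of the new paligemma decoder module.
    from ai_edge_torch.generative.examples.paligemma import decoder

    model = decoder.build_decoder(
        "/path/to/paligemma-3b-checkpoint",  # placeholder checkpoint directory
        kv_cache_max_len=1024,  # forwarded to get_decoder_config()
    )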
ai_edge_torch/generative/examples/paligemma/verify_decoder.py ADDED
@@ -0,0 +1,75 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum number of new tokens to generate.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = decoder.build_decoder(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
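
Since the script above is an absl app with --prompts and --max_new_tokens flags, an invocation from a source checkout would presumably look like the following (the flag values shown are just the defaults, for illustration):

    python ai_edge_torch/generative/examples/paligemma/verify_decoder.py \
        --prompts="What is the meaning of life?" \
        --max_new_tokens=30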
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -143,7 +143,7 @@ def verify_with_input_ids(
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-) -> bool:
+):
   """Verifies if the model reauthored generates the same output as the original.
 
   It compares only one output from the original and the reauthored model.
@@ -157,8 +157,9 @@ def verify_with_input_ids(
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
 
-  Returns:
-    True if the model reauthored generates the same output as the original.
+  Raises:
+    AssertionError if the model reauthored fails to generate the same output as
+      the original.
   """
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
@@ -173,7 +174,7 @@ def verify_with_input_ids(
   logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
   logging.info("logits_reauthored: %s", logits_reauthored)
 
-  return torch.allclose(
+  assert torch.allclose(
       logits_original, logits_reauthored, rtol=rtol, atol=atol
   )
 
@@ -184,7 +185,7 @@ def verify_model_with_prompts(
     tokenizer: TokenizerWrapper,
     prompts: str,
     max_new_tokens: int,
-) -> bool:
+):
   """Verifies if the model reauthored generates the same answer as the original.
 
   It compares an answer, i.e. multiple continuous outputs generated by the
@@ -198,8 +199,9 @@ def verify_model_with_prompts(
     prompts (str): The input prompts to generate answers.
     max_new_tokens (int): The maximum number of new tokens to generate.
 
-  Returns:
-    True if the model reauthored generates the same answer as the original.
+  Raises:
+    AssertionError if the model reauthored fails to generate the same answer as
+      the original.
   """
   prompt_tokens = tokenizer.encode(prompts)
 
@@ -213,7 +215,7 @@ def verify_model_with_prompts(
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
-  return response_original == response_reauthored
+  assert response_original == response_reauthored
 
 
 def verify_reauthored_model(
@@ -225,6 +227,7 @@ def verify_reauthored_model(
     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
+    continue_on_failure: bool = False,
 ):
   """Verifies the reauthored model against the original model.
 
@@ -247,21 +250,31 @@ def verify_reauthored_model(
       forward with.
     rtol (float): The relative tolerance for the comparison.
    atol (float): The absolute tolerance for the comparison.
+    continue_on_failure (bool): If True, verification continues with the next
+      prompt or input IDs even if a previous one fails.
   """
   for input_ids in forward_input_ids:
     logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
-    if verify_with_input_ids(
-        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
-    ):
-      logging.info("PASS")
+    try:
+      verify_with_input_ids(
+          original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+      if not continue_on_failure:
+        raise e
     else:
-      logging.error("FAILED")
+      logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
 
   for prompts in generate_prompts:
-    logging.info("Verifying the reauthored model with prompts:%s", prompts)
-    if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, prompts, max_new_tokens
-    ):
-      logging.info("PASS")
+    logging.info("Verifying the reauthored model with prompts: %s", prompts)
+    try:
+      verify_model_with_prompts(
+          original_model, reauthored_model, tokenizer, prompts, max_new_tokens
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with prompts: %s", prompts)
+      if not continue_on_failure:
+        raise e
     else:
-      logging.error("FAILED")
+      logging.info("*** PASSED *** verify with prompts: %s", prompts)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241108"
+__version__ = "0.3.0.dev20241109"
{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241109.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241108
+Version: 0.3.0.dev20241109
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241109.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=sBOl0mgVPJtokiP8qTbTtY0R_qIaF0KNiALh7P3AJEk,706
+ai_edge_torch/version.py,sha256=gXIkg8ND_yLshQAgj7V9B03CLweQwexHau-idN_Q-SQ,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -58,8 +58,11 @@ ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaX
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=JsrtuUY4q1Rovxsht2cGCuANUj1sUKnah6bAoSe8AoU,4387
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
+ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
+ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
@@ -139,7 +142,7 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HB
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx07esOYCQYbDVLgZ2o,9520
+ai_edge_torch/generative/utilities/verifier.py,sha256=5C2cm54d9kwL7nGRX-YfnBIJny1ICNhiU-LB3IqJq2E,10075
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -186,8 +189,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/METADATA,sha256=gp2VN_X4YPdK8axZYIhqafgiJhCwfiN_tOWT-yL3lW0,1897
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241109.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241109.dist-info/METADATA,sha256=Bx7bAskkKsT43fx4RZoccotaUwx2QOPeSWUfnarpsU4,1897
+ai_edge_torch_nightly-0.3.0.dev20241109.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241109.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241109.dist-info/RECORD,,
{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241109.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.44.0)
+Generator: bdist_wheel (0.45.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 