ai-edge-torch-nightly 0.4.0.dev20250328__py3-none-any.whl → 0.4.0.dev20250330__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Gemma3 model."""
+
+ import glob
+ import logging
+ import os
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.gemma3 import verify_util
+ import kagglehub
+
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum number of new tokens to generate.",
+ )
+ _CHECKPOINT = flags.DEFINE_string(
+     "checkpoint",
+     "",
+     "The checkpoint to verify.",
+ )
+ _VARIANT = flags.DEFINE_string(
+     "variant",
+     "1b",
+     "The variant of the model to verify.",
+ )
+ _WEIGHT_FILENAME = flags.DEFINE_string(
+     "weight_filename",
+     None,
+     "The weight filename of the model to verify.",
+ )
+
+
+ def find_first_ckpt(folder):
+   """Finds the first .ckpt file in a folder."""
+   ckpt_files = sorted(glob.glob(os.path.join(folder, "*.ckpt")))
+   return os.path.basename(ckpt_files[0]) if ckpt_files else None
+
+
+ def main(_):
+   if _CHECKPOINT.value:
+     checkpoint = _CHECKPOINT.value
+   else:
+     checkpoint = kagglehub.model_download(
+         "google/gemma-3/pyTorch/gemma-3-1b-it"
+     )
+
+   # If the weight filename is not specified, use the first .ckpt file found.
+   if _WEIGHT_FILENAME.value is None:
+     weight_filename = find_first_ckpt(checkpoint)
+     logging.info(
+         "NOTE: using the first weight file `%s` from `%s`",
+         weight_filename,
+         checkpoint,
+     )
+   else:
+     weight_filename = _WEIGHT_FILENAME.value
+
+   # Verify the reauthored model by comparing the outputs with the original one.
+   verify_util.verify_gemma3(
+       checkpoint,
+       _PROMPTS.value,
+       _MAX_NEW_TOKENS.value,
+       _VARIANT.value,
+       weight_filename,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
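
For context, the flag-driven flow above can also be exercised directly from Python. A minimal sketch mirroring what main() does, assuming a locally downloaded Gemma 3 1B checkpoint; the directory and weight filename below are placeholders, not values shipped in the wheel:

    from ai_edge_torch.generative.examples.gemma3 import verify_util

    # Placeholder paths for illustration; main() above obtains the checkpoint
    # directory from kagglehub.model_download() when --checkpoint is not set.
    verify_util.verify_gemma3(
        "/tmp/gemma-3-1b-it",              # checkpoint directory (placeholder)
        ["What is the meaning of life?"],  # prompts
        30,                                # max_new_tokens
        "1b",                              # variant
        "model.ckpt",                      # weight filename (placeholder)
    )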
ai_edge_torch/generative/examples/gemma3/verify_util.py ADDED
@@ -0,0 +1,247 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utility functions to verify the reauthored Gemma model."""
+
+ import logging
+ import os
+ from typing import List, Optional, Tuple
+
+ from ai_edge_torch.generative.examples.gemma3 import gemma3
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+ from ai_edge_torch.generative.utilities.experimental import verifier
+ from gemma import config as gemma_config
+ from gemma import model as gemma_model
+ import torch
+
+
+ def _get_actual_input_len(tokens: torch.Tensor) -> int:
+   for i in range(tokens.shape[1]):
+     if tokens[0, i] == 0:
+       return i
+   return tokens.shape[1]
+
+
+ class GemmaWrapper(verifier.ModelWrapper):
+   """Gemma model wrapper for verification.
+
+   The verifier calls model.forward() with the maximum sequence length (1024)
+   and expects logits as the output, while Gemma takes input tokens of the
+   actual length and returns the logits inside a tuple.
+
+   The verifier runs the tokenizer before model.generate(), while Gemma runs
+   the tokenizer inside model.generate().
+   """
+
+   def _get_kv_caches(
+       self, max_seq_len: int
+   ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+     config = self.model.config
+     cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+     cache = torch.zeros(cache_size)
+     return [
+         (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+     ]
+
+   def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+     """Forwards the model after reducing input tokens to the actual length."""
+     actual_input_len = _get_actual_input_len(tokens)
+     input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+     mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+     local_mask_cache = attn_utils.build_sliding_window_mask_cache(
+         tokens.shape[1], self.model.config.sliding_window_size
+     )
+     _, logits = self.model.forward(
+         input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+         input_positions=input_pos,
+         kv_write_indices=None,
+         kv_caches=self._get_kv_caches(tokens.shape[1]),
+         mask=mask_cache.index_select(2, input_pos),
+         output_positions=input_pos,
+         temperatures=None,
+         top_ps=torch.tensor([1.0], dtype=torch.float),
+         top_ks=torch.tensor([1], dtype=torch.long),
+         local_mask=local_mask_cache.index_select(2, input_pos),
+     )
+     return logits
+
+   def generate(
+       self, tokens: torch.Tensor, max_new_tokens: int
+   ) -> torch.IntTensor:
+     """Generates the response after decoding the tokens into a string."""
+     prompts = self.model.tokenizer.decode(tokens[0].tolist())
+     response = self.model.generate(
+         prompts, device="cpu", output_len=max_new_tokens, top_k=1
+     )
+     return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+ class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
+   """Unified Gemma3 model wrapper for verification."""
+
+   def _init_kv_cache(self):
+     """Returns an initialized KV cache."""
+     return kv_utils.KVCacheTransposed.from_model_config(self.model.model.config)
+
+   def forward(
+       self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
+   ) -> torch.Tensor:
+     """Forwards the model."""
+     mask = attn_utils.build_causal_mask_cache(
+         self.model.model.config.kv_cache_max_len
+     )
+     input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+     mask = mask.index_select(2, input_pos)
+     output = self.model.model.forward(
+         tokens, input_pos, self._init_kv_cache(), mask=mask
+     )
+     return output["logits"]
+
+   def generate(
+       self,
+       prompts: torch.Tensor,
+       max_new_tokens: int,
+       pixel_values: torch.Tensor = None,
+       eos_token_id: Optional[int] = None,
+   ) -> torch.IntTensor:
+     """Generates the response."""
+     input_ids = prompts[0].int().tolist()
+     tokens = torch.tensor([input_ids])
+     input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+     kv_cache = self._init_kv_cache()
+     mask_cache = attn_utils.build_causal_mask_cache(
+         self.model.model.config.kv_cache_max_len
+     )
+     for _ in range(max_new_tokens):
+       mask = mask_cache.index_select(2, input_pos)
+       output = self.model.model.forward(
+           tokens, input_pos, kv_cache, mask=mask
+       )
+       logits, kv_cache = output["logits"], output["kv_cache"]
+       generated_token = logits[0][-1].argmax().item()
+       input_ids.append(generated_token)
+       if eos_token_id is not None and generated_token == eos_token_id:
+         break
+       tokens = torch.tensor([[generated_token]])
+       input_pos = torch.tensor([len(input_ids) - 1])
+     return torch.tensor([input_ids])
+
+
+ class GemmaTokenizerWrapper(verifier.TokenizerWrapper):
+   """Tokenizer wrapper for verification.
+
+   The verifier expects the tokenizer to handle tokens as a torch.Tensor,
+   while the Gemma tokenizer expects them as a list.
+   """
+
+   def encode(self, text: str, **_) -> torch.Tensor:
+     """Adds one more dimension to the output of the tokenizer."""
+     return torch.tensor([self.tokenizer.encode(text)])
+
+   def decode(self, tokens: torch.Tensor) -> str:
+     """Decodes the token sequence after converting it to a list."""
+     return self.tokenizer.decode(tokens.tolist())
+
+
+ def verify_reauthored_gemma_model(
+     checkpoint: str,
+     variant: str,
+     reauthored_model: torch.nn.Module,
+     generate_prompts: List[str],
+     forward_input_ids: List[List[int]],
+     weight_filename: str,
+     tokenizer_filename: str = "tokenizer.model",
+     max_new_tokens: int = 20,
+     rtol: float = 1e-05,
+     atol: float = 1e-05,
+ ) -> bool:
+   """Verifies the reauthored Gemma model against the original model.
+
+   Args:
+     checkpoint: Path to the Gemma checkpoint.
+     variant: Gemma model variant.
+     reauthored_model: The reauthored model to verify.
+     generate_prompts: List of prompts for generation.
+     forward_input_ids: List of input ids for the forward pass.
+     weight_filename: Name of the weight file.
+     tokenizer_filename: Name of the tokenizer file.
+     max_new_tokens: Maximum number of new tokens to generate.
+     rtol: Relative tolerance for comparison.
+     atol: Absolute tolerance for comparison.
+
+   Returns:
+     True if the verification passes, False otherwise.
+   """
+   config = gemma_config.get_model_config(variant)
+   config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+   # Use float32 to be compatible with the reauthored model.
+   config.dtype = torch.float32
+
+   logging.info("Loading the original model from: %s", checkpoint)
+   original_model = gemma_model.GemmaForCausalLM(config).eval()
+   original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+   return verifier.verify_reauthored_model(
+       original_model=GemmaWrapper(original_model),
+       reauthored_model=UnifiedGemma3Wrapper(reauthored_model),
+       tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
+       generate_prompts=generate_prompts,
+       max_new_tokens=max_new_tokens,
+       forward_input_ids=forward_input_ids,
+       rtol=rtol,
+       atol=atol,
+   )
+
+
+ def verify_gemma3(
+     checkpoint: str,
+     prompts: List[str],
+     max_new_tokens: int,
+     variant: str,
+     weight_filename: str,
+ ) -> bool:
+   """Verifies the reauthored Gemma3 model.
+
+   Args:
+     checkpoint: Path to the Gemma checkpoint.
+     prompts: List of prompts for generation.
+     max_new_tokens: Maximum number of new tokens to generate.
+     variant: Gemma model variant.
+     weight_filename: Name of the weight file.
+
+   Returns:
+     True if the verification passes, False otherwise.
+   """
+   gemma3_model_path = os.path.join(checkpoint, weight_filename)
+   logging.info("Building the reauthored model from: %s", gemma3_model_path)
+
+   if variant == "1b":
+     reauthored_model = UnifiedGemma3Wrapper(
+         gemma3.build_model_1b(gemma3_model_path)
+     )
+   else:
+     raise ValueError(f"Unsupported Gemma3 variant: {variant}")
+
+   return verify_reauthored_gemma_model(
+       checkpoint=checkpoint,
+       variant=variant,
+       reauthored_model=reauthored_model,
+       generate_prompts=prompts,
+       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+       max_new_tokens=max_new_tokens,
+       weight_filename=weight_filename,
+       atol=1e-04,
+   )
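
Both wrapper classes above share one pattern worth noting: a full causal attention mask is built once, then only the rows for the current token positions are sliced out before each forward() call. A minimal standalone sketch of that slicing, using only the attention_utils helper this file already imports:

    import torch
    import ai_edge_torch.generative.layers.attention_utils as attn_utils

    # The cache has shape (1, 1, max_len, max_len); entries above the
    # diagonal are -inf so a token cannot attend to future positions.
    mask_cache = attn_utils.build_causal_mask_cache(8)
    input_pos = torch.arange(0, 3, dtype=torch.int)
    # Keep only the rows for positions 0..2, as the wrappers do above;
    # the result has shape (1, 1, 3, 8).
    mask = mask_cache.index_select(2, input_pos)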
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.4.0.dev20250328"
+ __version__ = "0.4.0.dev20250330"
{ai_edge_torch_nightly-0.4.0.dev20250328.dist-info → ai_edge_torch_nightly-0.4.0.dev20250330.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.4.0.dev20250328
+ Version: 0.4.0.dev20250330
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.4.0.dev20250328.dist-info → ai_edge_torch_nightly-0.4.0.dev20250330.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=IyhMWqN-g3wNhaYTXhegaL93NTmZkKXsXS6yx4E2kko,706
+ ai_edge_torch/version.py,sha256=vjzDDfl72IBJApHHN-UvaZwIZohudx1WyOQK2VwUXa4,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -69,6 +69,8 @@ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=xAjM
  ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=4Vf1zA94qLyNzj9iLU0jrd3kzFFZXft4uiItoIBjKyM,15632
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=NQzqZ55cmC8tGlZ1SKkDeD0Su8mZ79KiazCS8X08xUY,6473
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
+ ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
+ ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=u30qiZu3HJCTt5noWqtf9PgGLKQ87ke4Zpa4cpG6-As,8883
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
  ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
@@ -240,8 +242,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
  ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/METADATA,sha256=j22JbcB95xcu4aqu-G4NRbh1NxwLi9GpPjJjvpLsaSE,1966
- ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.4.0.dev20250330.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.4.0.dev20250330.dist-info/METADATA,sha256=WJm9GPjx-Tg4b_ABbbP7hVrvLR5cgcR350UBcuI2j9A,1966
+ ai_edge_torch_nightly-0.4.0.dev20250330.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.4.0.dev20250330.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.4.0.dev20250330.dist-info/RECORD,,
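
Each RECORD line above has the form path,sha256=DIGEST,SIZE, where the digest is the unpadded URL-safe base64 encoding of the file's SHA-256 hash, per the wheel spec (PEP 376/427). A minimal sketch of how such an entry can be recomputed; the function name is illustrative, not part of the package:

    import base64
    import hashlib

    def record_hash(payload: bytes) -> str:
      # Wheel RECORD hashes: URL-safe base64 of the SHA-256 digest, no padding.
      digest = hashlib.sha256(payload).digest()
      return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

    # For example, hashing the new version.py contents should reproduce the
    # "sha256=vjzDDfl72IBJApHHN-UvaZwIZohudx1WyOQK2VwUXa4" value listed above.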