ai-edge-torch-nightly 0.6.0.dev20250611__py3-none-any.whl → 0.6.0.dev20250613__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +4 -1
- ai_edge_torch/generative/examples/gemma/verify_util.py +9 -2
- ai_edge_torch/generative/utilities/verifier.py +2 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +12 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250611.dist-info → ai_edge_torch_nightly-0.6.0.dev20250613.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250611.dist-info → ai_edge_torch_nightly-0.6.0.dev20250613.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.6.0.dev20250611.dist-info → ai_edge_torch_nightly-0.6.0.dev20250613.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250611.dist-info → ai_edge_torch_nightly-0.6.0.dev20250613.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250611.dist-info → ai_edge_torch_nightly-0.6.0.dev20250613.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from absl import app
|
|
20
20
|
from absl import flags
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
22
22
|
from ai_edge_torch.generative.examples.gemma import verify_util
|
23
|
+
from ai_edge_torch.generative.utilities import verifier
|
23
24
|
import kagglehub
|
24
25
|
|
25
26
|
|
@@ -39,7 +40,9 @@ def main(_):
|
|
39
40
|
checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
|
40
41
|
|
41
42
|
logging.info("Building the reauthored model from: %s", checkpoint)
|
42
|
-
reauthored_model = gemma1.build_2b_model(
|
43
|
+
reauthored_model = gemma1.build_2b_model(
|
44
|
+
checkpoint, mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN
|
45
|
+
)
|
43
46
|
|
44
47
|
verify_util.verify_reauthored_gemma_model(
|
45
48
|
checkpoint=checkpoint,
|
@@ -62,8 +62,13 @@ class GemmaWrapper(verifier.ModelWrapper):
|
|
62
62
|
actual_input_len = self._get_actual_input_len(tokens)
|
63
63
|
input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
|
64
64
|
mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
|
65
|
-
local_mask_cache =
|
66
|
-
|
65
|
+
local_mask_cache = (
|
66
|
+
attn_utils.build_sliding_window_mask_cache(
|
67
|
+
tokens.shape[1], self.model.config.sliding_window_size
|
68
|
+
)
|
69
|
+
if self.model.config.sliding_window_size
|
70
|
+
else None # Do not use local mask cache if sliding window size is None.
|
71
|
+
)
|
67
72
|
_, logits = self.model.forward(
|
68
73
|
input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
|
69
74
|
input_positions=input_pos,
|
@@ -75,6 +80,8 @@ class GemmaWrapper(verifier.ModelWrapper):
|
|
75
80
|
top_ps=torch.tensor([1.0], dtype=torch.float),
|
76
81
|
top_ks=torch.tensor([1], dtype=torch.long),
|
77
82
|
local_mask=local_mask_cache.index_select(2, input_pos)
|
83
|
+
if local_mask_cache
|
84
|
+
else None,
|
78
85
|
)
|
79
86
|
return logits
|
80
87
|
|
@@ -351,6 +351,7 @@ def verify_reauthored_model(
|
|
351
351
|
)
|
352
352
|
except AssertionError as e:
|
353
353
|
logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
|
354
|
+
logging.error("*** Assertion Error: %s", e)
|
354
355
|
failure_count += 1
|
355
356
|
if not continue_on_failure:
|
356
357
|
return False
|
@@ -366,6 +367,7 @@ def verify_reauthored_model(
|
|
366
367
|
)
|
367
368
|
except AssertionError as e:
|
368
369
|
logging.error("*** FAILED *** verify with prompts: %s", prompts)
|
370
|
+
logging.error("*** Assertion Error: %s", e)
|
369
371
|
failure_count += 1
|
370
372
|
if not continue_on_failure:
|
371
373
|
return False
|
@@ -331,6 +331,18 @@ def _aten_sym_size_int(lctx, x: ir.Value, dim: int):
|
|
331
331
|
return stablehlo.get_dimension_size(x, dim)
|
332
332
|
|
333
333
|
|
334
|
+
# Lowering for the subtraction operator (`-`).
|
335
|
+
# Handles cases where one operand is an integer (scalar) and the other is a
|
336
|
+
# tensor, broadcasting the scalar to the tensor's shape before subtraction.
|
337
|
+
@lower(operator.sub)
|
338
|
+
def _operator_sub(lctx, self: int | ir.Value, other: int | ir.Value):
|
339
|
+
if isinstance(self, int) and isinstance(other, ir.Value):
|
340
|
+
self = utils.splat(self, other.type.element_type, other.type.shape)
|
341
|
+
if isinstance(other, int) and isinstance(self, ir.Value):
|
342
|
+
other = utils.splat(other, self.type.element_type, self.type.shape)
|
343
|
+
return stablehlo.subtract(self, other)
|
344
|
+
|
345
|
+
|
334
346
|
# Lowering for the multiplication operator (`*`).
|
335
347
|
# Handles cases where one operand is an integer (scalar) and the other is a
|
336
348
|
# tensor, broadcasting the scalar to the tensor's shape before multiplication.
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.6.0.dev20250611
|
3
|
+
Version: 0.6.0.dev20250613
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=YTxxjYgksS0xe-Vy2RMt-V4C9dddyuKoIcLfmaTZBh4,806
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -67,9 +67,9 @@ ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=m5N3M
|
|
67
67
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=fR4869w1UZIuUVGLDdPof0IDEMDhqn2ej5kl7QW9vNA,1882
|
68
68
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=TH9XQAp5p4S829XbaWbJQZBwB18WizDRIQMsUkKqj38,3377
|
69
69
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=E6jotWYYIx6SXUWqurKWjiZpbfj_M2jJrBc2rQ90z1s,11782
|
70
|
-
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=
|
70
|
+
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=USyist332qZkhCBof2tJwQSqtnKjTQsKAK_jCE_CO2U,1853
|
71
71
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=eAM7EVVMW-QCqjeZEss7TOkVKArgUs1La51LAC-5a9A,1962
|
72
|
-
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=
|
72
|
+
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=b12naCj4zZxOjkIKrd08qovtajYuX-Ba3fbrv6kkDZs,8410
|
73
73
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
74
74
|
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=UEDNN3JmI31WfE2pvacxeJpqumKK86L2dEus3yTURaY,2114
|
75
75
|
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=1UVv9SFFg5degX3wf-Fefx7nor1AzJj2NWBVuo8bRnM,15540
|
@@ -218,7 +218,7 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yG
|
|
218
218
|
ai_edge_torch/generative/utilities/test_utils.py,sha256=fhUMCMxoeMzxYbOCjNeX5wbQmF6Y88Hi52FtRiZYJAk,1147
|
219
219
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=l54bmmhj613eB2oCoONIAKEHhf8TQOhC9Gwjp6lxHAE,1659
|
220
220
|
ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
|
221
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
221
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=HmpA5-q7s1a5WOoO5gMlfdlPp5oLeKwnx56n3DUOBBM,13806
|
222
222
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
223
223
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
|
224
224
|
ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
|
@@ -248,7 +248,7 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
|
|
248
248
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
249
249
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
250
250
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
|
251
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
251
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VWb5HEeVljnuXi1eecKp1ieOIcBrSLlu7YIZnxnrozU,12198
|
252
252
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
253
253
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
254
254
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
|
@@ -268,8 +268,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
268
268
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
269
269
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
270
270
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
271
|
-
ai_edge_torch_nightly-0.6.0.
|
272
|
-
ai_edge_torch_nightly-0.6.0.
|
273
|
-
ai_edge_torch_nightly-0.6.0.
|
274
|
-
ai_edge_torch_nightly-0.6.0.
|
275
|
-
ai_edge_torch_nightly-0.6.0.
|
271
|
+
ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
272
|
+
ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/METADATA,sha256=WVeg8CPHQeHC3GNBou1kvlzF8zXyPtpSadLTayhpmew,2074
|
273
|
+
ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
274
|
+
ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
275
|
+
ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/RECORD,,
|
File without changes
|
File without changes
|