ai-edge-torch-nightly 0.6.0.dev20250611__py3-none-any.whl → 0.6.0.dev20250613__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,6 +20,7 @@ from absl import app
20
20
  from absl import flags
21
21
  from ai_edge_torch.generative.examples.gemma import gemma1
22
22
  from ai_edge_torch.generative.examples.gemma import verify_util
23
+ from ai_edge_torch.generative.utilities import verifier
23
24
  import kagglehub
24
25
 
25
26
 
@@ -39,7 +40,9 @@ def main(_):
39
40
  checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
40
41
 
41
42
  logging.info("Building the reauthored model from: %s", checkpoint)
42
- reauthored_model = gemma1.build_2b_model(checkpoint)
43
+ reauthored_model = gemma1.build_2b_model(
44
+ checkpoint, mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN
45
+ )
43
46
 
44
47
  verify_util.verify_reauthored_gemma_model(
45
48
  checkpoint=checkpoint,
@@ -62,8 +62,13 @@ class GemmaWrapper(verifier.ModelWrapper):
62
62
  actual_input_len = self._get_actual_input_len(tokens)
63
63
  input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
64
64
  mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
65
- local_mask_cache = attn_utils.build_sliding_window_mask_cache(
66
- tokens.shape[1], self.model.config.sliding_window_size)
65
+ local_mask_cache = (
66
+ attn_utils.build_sliding_window_mask_cache(
67
+ tokens.shape[1], self.model.config.sliding_window_size
68
+ )
69
+ if self.model.config.sliding_window_size
70
+ else None # Do not use local mask cache if sliding window size is None.
71
+ )
67
72
  _, logits = self.model.forward(
68
73
  input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
69
74
  input_positions=input_pos,
@@ -75,6 +80,8 @@ class GemmaWrapper(verifier.ModelWrapper):
75
80
  top_ps=torch.tensor([1.0], dtype=torch.float),
76
81
  top_ks=torch.tensor([1], dtype=torch.long),
77
82
  local_mask=local_mask_cache.index_select(2, input_pos)
83
+ if local_mask_cache
84
+ else None,
78
85
  )
79
86
  return logits
80
87
 
@@ -351,6 +351,7 @@ def verify_reauthored_model(
351
351
  )
352
352
  except AssertionError as e:
353
353
  logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
354
+ logging.error("*** Assertion Error: %s", e)
354
355
  failure_count += 1
355
356
  if not continue_on_failure:
356
357
  return False
@@ -366,6 +367,7 @@ def verify_reauthored_model(
366
367
  )
367
368
  except AssertionError as e:
368
369
  logging.error("*** FAILED *** verify with prompts: %s", prompts)
370
+ logging.error("*** Assertion Error: %s", e)
369
371
  failure_count += 1
370
372
  if not continue_on_failure:
371
373
  return False
@@ -331,6 +331,18 @@ def _aten_sym_size_int(lctx, x: ir.Value, dim: int):
331
331
  return stablehlo.get_dimension_size(x, dim)
332
332
 
333
333
 
334
+ # Lowering for the subtraction operator (`-`).
335
+ # Handles cases where one operand is an integer (scalar) and the other is a
336
+ # tensor, broadcasting the scalar to the tensor's shape before subtraction.
337
+ @lower(operator.sub)
338
+ def _operator_sub(lctx, self: int | ir.Value, other: int | ir.Value):
339
+ if isinstance(self, int) and isinstance(other, ir.Value):
340
+ self = utils.splat(self, other.type.element_type, other.type.shape)
341
+ if isinstance(other, int) and isinstance(self, ir.Value):
342
+ other = utils.splat(other, self.type.element_type, self.type.shape)
343
+ return stablehlo.subtract(self, other)
344
+
345
+
334
346
  # Lowering for the multiplication operator (`*`).
335
347
  # Handles cases where one operand is an integer (scalar) and the other is a
336
348
  # tensor, broadcasting the scalar to the tensor's shape before multiplication.
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.6.0.dev20250611"
18
+ __version__ = "0.6.0.dev20250613"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.6.0.dev20250611
3
+ Version: 0.6.0.dev20250613
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
- ai_edge_torch/version.py,sha256=03SgV_6zF-xAWJAi3O3nHJh9aG0F05fy5C5nLfoj5-Q,806
5
+ ai_edge_torch/version.py,sha256=YTxxjYgksS0xe-Vy2RMt-V4C9dddyuKoIcLfmaTZBh4,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -67,9 +67,9 @@ ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=m5N3M
67
67
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=fR4869w1UZIuUVGLDdPof0IDEMDhqn2ej5kl7QW9vNA,1882
68
68
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=TH9XQAp5p4S829XbaWbJQZBwB18WizDRIQMsUkKqj38,3377
69
69
  ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=E6jotWYYIx6SXUWqurKWjiZpbfj_M2jJrBc2rQ90z1s,11782
70
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
70
+ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=USyist332qZkhCBof2tJwQSqtnKjTQsKAK_jCE_CO2U,1853
71
71
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=eAM7EVVMW-QCqjeZEss7TOkVKArgUs1La51LAC-5a9A,1962
72
- ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=vFhb3mOxB0Pps-Nx8FQWskh-Zly7N83n_EPyG22R-oM,8204
72
+ ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=b12naCj4zZxOjkIKrd08qovtajYuX-Ba3fbrv6kkDZs,8410
73
73
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
74
74
  ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=UEDNN3JmI31WfE2pvacxeJpqumKK86L2dEus3yTURaY,2114
75
75
  ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=1UVv9SFFg5degX3wf-Fefx7nor1AzJj2NWBVuo8bRnM,15540
@@ -218,7 +218,7 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yG
218
218
  ai_edge_torch/generative/utilities/test_utils.py,sha256=fhUMCMxoeMzxYbOCjNeX5wbQmF6Y88Hi52FtRiZYJAk,1147
219
219
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=l54bmmhj613eB2oCoONIAKEHhf8TQOhC9Gwjp6lxHAE,1659
220
220
  ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
221
- ai_edge_torch/generative/utilities/verifier.py,sha256=lSzVBFkfnVcFr3rHYQqf1OxfxgQrevkF6jGdqmZTJHA,13702
221
+ ai_edge_torch/generative/utilities/verifier.py,sha256=HmpA5-q7s1a5WOoO5gMlfdlPp5oLeKwnx56n3DUOBBM,13806
222
222
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
223
223
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
224
224
  ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
@@ -248,7 +248,7 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
248
248
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
249
249
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
250
250
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
251
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
251
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VWb5HEeVljnuXi1eecKp1ieOIcBrSLlu7YIZnxnrozU,12198
252
252
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
253
253
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
254
254
  ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
@@ -268,8 +268,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
268
268
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
269
269
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
270
270
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
271
- ai_edge_torch_nightly-0.6.0.dev20250611.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
272
- ai_edge_torch_nightly-0.6.0.dev20250611.dist-info/METADATA,sha256=EPsBf0cXltLduemjqkavDrVuu4orHWptAtRfnnDGnoM,2074
273
- ai_edge_torch_nightly-0.6.0.dev20250611.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
274
- ai_edge_torch_nightly-0.6.0.dev20250611.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
275
- ai_edge_torch_nightly-0.6.0.dev20250611.dist-info/RECORD,,
271
+ ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
272
+ ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/METADATA,sha256=WVeg8CPHQeHC3GNBou1kvlzF8zXyPtpSadLTayhpmew,2074
273
+ ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
274
+ ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
275
+ ai_edge_torch_nightly-0.6.0.dev20250613.dist-info/RECORD,,