ai-edge-torch-nightly 0.3.0.dev20241211__py3-none-any.whl → 0.3.0.dev20241212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,8 @@
15
15
 
16
16
  """Example of building a decoder of PaliGemma 3B model which is Gemma1."""
17
17
 
18
+ from typing import Optional
19
+
18
20
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
19
21
  import ai_edge_torch.generative.layers.model_config as cfg
20
22
  from ai_edge_torch.generative.utilities import model_builder
@@ -51,6 +53,7 @@ class Decoder(model_builder.DecoderOnlyModel):
51
53
  input_pos: torch.Tensor,
52
54
  kv_cache: kv_utils.KVCache,
53
55
  input_embeds: torch.Tensor = None,
56
+ export_config: Optional[model_builder.ExportConfig] = None,
54
57
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
55
58
  if input_embeds is None:
56
59
  return super().forward(tokens, input_pos, kv_cache)
@@ -16,11 +16,13 @@
16
16
  """Example of building a full-stack of PaliGemma model."""
17
17
 
18
18
  from dataclasses import dataclass
19
+ from typing import Optional
19
20
 
20
21
  from ai_edge_torch.generative.examples.paligemma import decoder
21
22
  from ai_edge_torch.generative.examples.paligemma import image_encoder
22
23
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
23
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
+ from ai_edge_torch.generative.utilities import model_builder
24
26
  import ai_edge_torch.generative.utilities.loader as loading_utils
25
27
  import torch
26
28
  from torch import nn
@@ -67,9 +69,16 @@ class PaliGemma(nn.Module):
67
69
  input_pos: torch.Tensor,
68
70
  kv_cache: kv_utils.KVCache,
69
71
  pixel_values: torch.Tensor = None,
72
+ export_config: Optional[model_builder.ExportConfig] = None,
70
73
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
71
74
  if pixel_values is None:
72
- return self.decoder(tokens, input_pos, kv_cache)
75
+ return self.decoder(
76
+ tokens=tokens,
77
+ input_pos=input_pos,
78
+ kv_cache=kv_cache,
79
+ input_embeds=None,
80
+ export_config=export_config
81
+ )
73
82
 
74
83
  input_embeds = self.decoder.tok_embedding(tokens)
75
84
 
@@ -100,6 +109,7 @@ class PaliGemma(nn.Module):
100
109
  input_pos=input_pos,
101
110
  kv_cache=kv_cache,
102
111
  input_embeds=input_embeds,
112
+ export_config=export_config,
103
113
  )
104
114
 
105
115
 
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
190
190
  """
191
191
  x = torch.permute(x, (0, 2, 3, 1))
192
192
 
193
- # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
194
- # int32 when the bug is fixed.
195
193
  builder = StableHLOCompositeBuilder(
196
194
  name="odml.group_norm",
197
195
  attr={
198
196
  "num_groups": num_groups,
199
197
  "epsilon": eps,
200
- "reduction_axes": 3,
198
+ "reduction_axes": [3],
201
199
  "channel_axis": 3,
202
200
  },
203
201
  )
@@ -29,7 +29,7 @@ class TransformersModelWrapper(verifier.ModelWrapper):
29
29
  an object with `logits` field.
30
30
 
31
31
  Transformers models get `max_new_tokens` settings for generate() via
32
- ExportConfig.
32
+ GenerationConfig.
33
33
  """
34
34
 
35
35
  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
@@ -38,5 +38,5 @@ class TransformersModelWrapper(verifier.ModelWrapper):
38
38
  def generate(
39
39
  self, inputs: torch.Tensor, max_new_tokens: int
40
40
  ) -> torch.IntTensor:
41
- export_config = transformers.ExportConfig(max_new_tokens=max_new_tokens)
42
- return self.model.generate(inputs=inputs, generation_config=export_config)
41
+ gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
42
+ return self.model.generate(inputs=inputs, generation_config=gen_config)
@@ -115,7 +115,7 @@ class ReauthoredModelWrapper(ModelWrapper):
115
115
  # pixel_values only when it is not None. Otherwise, it may raise an error.
116
116
  if pixel_values is None:
117
117
  output = self.model.forward(
118
- tokens, input_pos, kv_cache, self.export_config
118
+ tokens, input_pos, kv_cache, export_config=self.export_config
119
119
  )
120
120
  else:
121
121
  output = self.model.forward(
@@ -16,12 +16,15 @@ import functools
16
16
  import logging
17
17
 
18
18
  from ai_edge_torch.odml_torch import jax_bridge
19
+ from ai_edge_torch.odml_torch.lowerings import context
20
+ from ai_edge_torch.odml_torch.lowerings import registry
21
+ import jax.numpy as jnp
22
+ from jax._src.lib.mlir import ir
19
23
  import torch
20
24
  import torch_xla2.ops.jaten # Import to load torch_xla2 ops
21
25
  import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
22
26
 
23
- from . import registry
24
-
27
+ LoweringContext = context.LoweringContext
25
28
 
26
29
  @functools.cache
27
30
  def _log_usage(op):
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
258
261
  @lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
259
262
  def _aten_copy(self, src, **kwargs):
260
263
  return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
264
+
265
+
266
+ # Schema:
267
+ # - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
268
+ # -> Tensor
269
+ # Torch Reference:
270
+ # - https://pytorch.org/docs/stable/generated/torch.einsum.html
271
+ # - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
272
+ @registry.lower(torch.ops.aten.einsum.default)
273
+ def _aten_einsum_default(
274
+ lctx: LoweringContext,
275
+ equation: str,
276
+ tensors: list[ir.Value],
277
+ path=None,
278
+ ):
279
+ _log_usage(torch.ops.aten.einsum.default)
280
+
281
+ @jax_bridge.wrap
282
+ def jax_lowering(operands):
283
+ # Ignore the input path and let JAX determine the path.
284
+ return jnp.einsum(equation, *operands, optimize="optimal")
285
+
286
+ return jax_lowering(lctx, tuple(tensors))
@@ -46,7 +46,13 @@ def decompositions():
46
46
 
47
47
  torch._decomp.remove_decompositions(
48
48
  decompositions,
49
- [torch.ops.aten.roll],
49
+ [
50
+ torch.ops.aten.roll,
51
+ # Torch's default einsum impl/decompositions is less efficient and
52
+ # optimized through converter than JAX's impl. Disable einsum
53
+ # decomposition to use JAX bridge for a more efficient lowering.
54
+ torch.ops.aten.einsum.default,
55
+ ],
50
56
  )
51
57
 
52
58
  # Override _safe_softmax decompositions with regular softmax.
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241211"
16
+ __version__ = "0.3.0.dev20241212"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241211
3
+ Version: 0.3.0.dev20241212
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -29,7 +29,7 @@ Requires-Dist: safetensors
29
29
  Requires-Dist: tabulate
30
30
  Requires-Dist: torch>=2.4.0
31
31
  Requires-Dist: torch-xla>=2.4.0
32
- Requires-Dist: tf-nightly>=2.19.0.dev20241121
32
+ Requires-Dist: tf-nightly>=2.19.0.dev20241201
33
33
  Requires-Dist: ai-edge-litert-nightly
34
34
  Requires-Dist: ai-edge-quantizer-nightly
35
35
 
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=_uS2Df0H-aUbz-7M-gLxfjDVOJxr03EeNDfbVC_cBrE,706
6
+ ai_edge_torch/version.py,sha256=QnJ2_alMOUe5ea0vTpY7AIBr8eoHvuwKaaj917g5DFA,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -64,9 +64,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
64
64
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
65
65
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
66
66
  ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
67
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=f_A3GWcLrP0nRq2Tq-fThfXIQVJ-EYWoExYLO_6iVIQ,4866
67
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
68
68
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
69
- ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
69
+ ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
70
70
  ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
71
71
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
72
72
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
@@ -121,7 +121,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
121
121
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
122
122
  ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
123
123
  ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
124
- ai_edge_torch/generative/layers/normalization.py,sha256=_2hps2m2MXEHQWbM-1B4he90hbq8wqOnIDIf-qXHhpc,7589
124
+ ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
125
125
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
126
126
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
127
127
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -151,8 +151,8 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7
151
151
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
152
152
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
153
153
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
154
- ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=nHmI27ybu7lj8Ufw2LzmCwRDqEwNppIFNTx5ltLHIgE,1547
155
- ai_edge_torch/generative/utilities/verifier.py,sha256=1NcmT_55Sb5e5spnHab4x5wqJZi2CKKVtXuXgK3lE6Q,11927
154
+ ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
155
+ ai_edge_torch/generative/utilities/verifier.py,sha256=ESSA8W1EYNsd4ntwmXbr-dn-BcIS27hf53XL5RTwjEU,11941
156
156
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
157
157
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
158
158
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -186,11 +186,11 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T
186
186
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
187
187
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
188
188
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
189
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
189
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
190
190
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
191
191
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
192
192
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
193
- ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=aR6JPFP2Iq-aR0qPxJEHehmAVTjiGhgQEoycZV_1vPY,2130
193
+ ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
194
194
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
195
195
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
196
196
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
@@ -201,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
201
201
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
202
202
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
203
203
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
204
- ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
205
- ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/METADATA,sha256=Lyub5vadYf6Yu6mGY7l1PFk8Jg2rB36ojIBHm9CxhBM,1897
206
- ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
207
- ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
208
- ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/RECORD,,
204
+ ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
205
+ ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/METADATA,sha256=tkJXZvoB1p4WKAKgK9Ql071JxwI7BwU3gKmdJR5jcrs,1897
206
+ ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
207
+ ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
208
+ ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/RECORD,,