ai-edge-torch-nightly 0.3.0.dev20241211__py3-none-any.whl → 0.3.0.dev20241212__py3-none-any.whl

ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -15,6 +15,8 @@
 
 """Example of building a decoder of PaliGemma 3B model which is Gemma1."""
 
+from typing import Optional
+
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -51,6 +53,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -16,11 +16,13 @@
 """Example of building a full-stack of PaliGemma model."""
 
 from dataclasses import dataclass
+from typing import Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -67,9 +69,16 @@ class PaliGemma(nn.Module):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
-      return self.decoder(tokens, input_pos, kv_cache)
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          input_embeds=None,
+          export_config=export_config
+      )
 
     input_embeds = self.decoder.tok_embedding(tokens)
 
@@ -100,6 +109,7 @@ class PaliGemma(nn.Module):
         input_pos=input_pos,
         kv_cache=kv_cache,
         input_embeds=input_embeds,
+        export_config=export_config,
     )
 
 
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
-  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
       attr={
          "num_groups": num_groups,
          "epsilon": eps,
-         "reduction_axes": 3,
+         "reduction_axes": [3],
          "channel_axis": 3,
       },
   )
ai_edge_torch/generative/utilities/transformers_verifier.py CHANGED
@@ -29,7 +29,7 @@ class TransformersModelWrapper(verifier.ModelWrapper):
   an object with `logits` field.
 
   Transformers models get `max_new_tokens` settings for generate() via
-  ExportConfig.
+  GenerationConfig.
   """
 
   def forward(self, tokens: torch.Tensor) -> torch.Tensor:
@@ -38,5 +38,5 @@ class TransformersModelWrapper(verifier.ModelWrapper):
   def generate(
       self, inputs: torch.Tensor, max_new_tokens: int
   ) -> torch.IntTensor:
-    export_config = transformers.ExportConfig(max_new_tokens=max_new_tokens)
-    return self.model.generate(inputs=inputs, generation_config=export_config)
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
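The rename matters because Hugging Face transformers exposes generation settings through `transformers.GenerationConfig`; there is no `transformers.ExportConfig`, so the old code would fail with an AttributeError. A minimal standalone sketch of the same pattern outside the wrapper, assuming a small Hugging Face causal LM (the model name is only illustrative):

# Illustrative sketch of the GenerationConfig pattern used above.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

tokens = tokenizer("The capital of France is", return_tensors="pt").input_ids
gen_config = transformers.GenerationConfig(max_new_tokens=8)
# generate() reads max_new_tokens from the GenerationConfig object.
output = model.generate(inputs=tokens, generation_config=gen_config)
print(tokenizer.decode(output[0]))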
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -115,7 +115,7 @@ class ReauthoredModelWrapper(ModelWrapper):
     # pixel_values only when it is not None. Otherwise, it may raise an error.
     if pixel_values is None:
       output = self.model.forward(
-          tokens, input_pos, kv_cache, self.export_config
+          tokens, input_pos, kv_cache, export_config=self.export_config
       )
     else:
       output = self.model.forward(
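The switch to a keyword argument is needed because the reauthored forward signatures above now have another optional parameter (`input_embeds` or `pixel_values`) ahead of `export_config`; passed positionally, the config would silently bind to that earlier parameter. A minimal self-contained sketch of the pitfall, using a hypothetical signature that mirrors the one in this diff (not the library code):

# Hypothetical signature mirroring Decoder.forward above.
def forward(tokens, input_pos, kv_cache, input_embeds=None, export_config=None):
  return {"input_embeds": input_embeds, "export_config": export_config}

cfg = object()  # stand-in for an ExportConfig instance

# Positional: cfg lands in input_embeds and export_config stays None.
wrong = forward("tok", 0, "kv", cfg)
assert wrong["input_embeds"] is cfg and wrong["export_config"] is None

# Keyword (as in the diff): the value binds to the intended parameter.
right = forward("tok", 0, "kv", export_config=cfg)
assert right["export_config"] is cfg and right["input_embeds"] is None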
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -16,12 +16,15 @@ import functools
 import logging
 
 from ai_edge_torch.odml_torch import jax_bridge
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+import jax.numpy as jnp
+from jax._src.lib.mlir import ir
 import torch
 import torch_xla2.ops.jaten  # Import to load torch_xla2 ops
 import torch_xla2.ops.ops_registry  # Import to load torch_xla2 ops
 
-from . import registry
-
+LoweringContext = context.LoweringContext
 
 @functools.cache
 def _log_usage(op):
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
 @lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
 def _aten_copy(self, src, **kwargs):
   return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
+
+
+# Schema:
+#   - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
+#       -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/stable/generated/torch.einsum.html
+#   - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
+@registry.lower(torch.ops.aten.einsum.default)
+def _aten_einsum_default(
+    lctx: LoweringContext,
+    equation: str,
+    tensors: list[ir.Value],
+    path=None,
+):
+  _log_usage(torch.ops.aten.einsum.default)
+
+  @jax_bridge.wrap
+  def jax_lowering(operands):
+    # Ignore the input path and let JAX determine the path.
+    return jnp.einsum(equation, *operands, optimize="optimal")
+
+  return jax_lowering(lctx, tuple(tensors))
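With this lowering registered (and the matching decomposition removal in decomp.py below), a `torch.einsum` call survives conversion as a single `aten.einsum.default` op and is lowered through the JAX bridge via `jnp.einsum`. A minimal sketch of the user-facing flow, assuming the standard `ai_edge_torch.convert`/`export` API from the project README; shapes and the output path are illustrative:

import torch
import ai_edge_torch


class EinsumScores(torch.nn.Module):
  """Toy module: attention-style scores written as an einsum."""

  def forward(self, q, k):
    # Kept as a single aten.einsum.default op rather than being decomposed.
    return torch.einsum("bqd,bkd->bqk", q, k)


sample_args = (torch.randn(1, 4, 8), torch.randn(1, 4, 8))
edge_model = ai_edge_torch.convert(EinsumScores().eval(), sample_args)
edge_model.export("/tmp/einsum_scores.tflite")  # illustrative output path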
ai_edge_torch/odml_torch/lowerings/decomp.py CHANGED
@@ -46,7 +46,13 @@ def decompositions():
 
   torch._decomp.remove_decompositions(
       decompositions,
-      [torch.ops.aten.roll],
+      [
+          torch.ops.aten.roll,
+          # Torch's default einsum decomposition is less efficient through the
+          # converter than JAX's implementation. Disable the decomposition so
+          # the JAX bridge provides a more efficient lowering.
+          torch.ops.aten.einsum.default,
+      ],
   )
 
   # Override _safe_softmax decompositions with regular softmax.
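For reference, `torch._decomp.remove_decompositions` simply drops entries from a decomposition table, so the listed ops are left intact for a later custom lowering instead of being expanded into smaller ops. A small standalone sketch of that API, assuming a recent PyTorch; whether einsum has an entry in the default table depends on the torch version:

import torch
from torch._decomp import core_aten_decompositions, remove_decompositions

table = core_aten_decompositions()
# Removal is a no-op if the op has no entry; otherwise aten.einsum.default is
# left undecomposed so a custom lowering (here, the JAX bridge) can handle it.
remove_decompositions(table, [torch.ops.aten.einsum.default])
assert torch.ops.aten.einsum.default not in table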
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241211"
+__version__ = "0.3.0.dev20241212"
ai_edge_torch_nightly-{0.3.0.dev20241211 → 0.3.0.dev20241212}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241211
+Version: 0.3.0.dev20241212
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -29,7 +29,7 @@ Requires-Dist: safetensors
 Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
-Requires-Dist: tf-nightly>=2.19.0.dev20241121
+Requires-Dist: tf-nightly>=2.19.0.dev20241201
 Requires-Dist: ai-edge-litert-nightly
 Requires-Dist: ai-edge-quantizer-nightly
 
ai_edge_torch_nightly-{0.3.0.dev20241211 → 0.3.0.dev20241212}.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=_uS2Df0H-aUbz-7M-gLxfjDVOJxr03EeNDfbVC_cBrE,706
+ai_edge_torch/version.py,sha256=QnJ2_alMOUe5ea0vTpY7AIBr8eoHvuwKaaj917g5DFA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -64,9 +64,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=f_A3GWcLrP0nRq2Tq-fThfXIQVJ-EYWoExYLO_6iVIQ,4866
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
@@ -121,7 +121,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
 ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
-ai_edge_torch/generative/layers/normalization.py,sha256=_2hps2m2MXEHQWbM-1B4he90hbq8wqOnIDIf-qXHhpc,7589
+ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -151,8 +151,8 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
-ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=nHmI27ybu7lj8Ufw2LzmCwRDqEwNppIFNTx5ltLHIgE,1547
-ai_edge_torch/generative/utilities/verifier.py,sha256=1NcmT_55Sb5e5spnHab4x5wqJZi2CKKVtXuXgK3lE6Q,11927
+ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
+ai_edge_torch/generative/utilities/verifier.py,sha256=ESSA8W1EYNsd4ntwmXbr-dn-BcIS27hf53XL5RTwjEU,11941
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -186,11 +186,11 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
-ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=aR6JPFP2Iq-aR0qPxJEHehmAVTjiGhgQEoycZV_1vPY,2130
+ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
@@ -201,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/METADATA,sha256=Lyub5vadYf6Yu6mGY7l1PFk8Jg2rB36ojIBHm9CxhBM,1897
-ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/METADATA,sha256=tkJXZvoB1p4WKAKgK9Ql071JxwI7BwU3gKmdJR5jcrs,1897
+ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/RECORD,,