ai-edge-torch-nightly 0.3.0.dev20241211__py3-none-any.whl → 0.3.0.dev20241212__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/paligemma/decoder.py +3 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/utilities/transformers_verifier.py +3 -3
- ai_edge_torch/generative/utilities/verifier.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +7 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,8 @@
|
|
15
15
|
|
16
16
|
"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""
|
17
17
|
|
18
|
+
from typing import Optional
|
19
|
+
|
18
20
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
19
21
|
import ai_edge_torch.generative.layers.model_config as cfg
|
20
22
|
from ai_edge_torch.generative.utilities import model_builder
|
@@ -51,6 +53,7 @@ class Decoder(model_builder.DecoderOnlyModel):
|
|
51
53
|
input_pos: torch.Tensor,
|
52
54
|
kv_cache: kv_utils.KVCache,
|
53
55
|
input_embeds: torch.Tensor = None,
|
56
|
+
export_config: Optional[model_builder.ExportConfig] = None,
|
54
57
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
55
58
|
if input_embeds is None:
|
56
59
|
return super().forward(tokens, input_pos, kv_cache)
|
@@ -16,11 +16,13 @@
|
|
16
16
|
"""Example of building a full-stack of PaliGemma model."""
|
17
17
|
|
18
18
|
from dataclasses import dataclass
|
19
|
+
from typing import Optional
|
19
20
|
|
20
21
|
from ai_edge_torch.generative.examples.paligemma import decoder
|
21
22
|
from ai_edge_torch.generative.examples.paligemma import image_encoder
|
22
23
|
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
23
24
|
import ai_edge_torch.generative.layers.model_config as cfg
|
25
|
+
from ai_edge_torch.generative.utilities import model_builder
|
24
26
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
25
27
|
import torch
|
26
28
|
from torch import nn
|
@@ -67,9 +69,16 @@ class PaliGemma(nn.Module):
|
|
67
69
|
input_pos: torch.Tensor,
|
68
70
|
kv_cache: kv_utils.KVCache,
|
69
71
|
pixel_values: torch.Tensor = None,
|
72
|
+
export_config: Optional[model_builder.ExportConfig] = None,
|
70
73
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
71
74
|
if pixel_values is None:
|
72
|
-
return self.decoder(
|
75
|
+
return self.decoder(
|
76
|
+
tokens=tokens,
|
77
|
+
input_pos=input_pos,
|
78
|
+
kv_cache=kv_cache,
|
79
|
+
input_embeds=None,
|
80
|
+
export_config=export_config
|
81
|
+
)
|
73
82
|
|
74
83
|
input_embeds = self.decoder.tok_embedding(tokens)
|
75
84
|
|
@@ -100,6 +109,7 @@ class PaliGemma(nn.Module):
|
|
100
109
|
input_pos=input_pos,
|
101
110
|
kv_cache=kv_cache,
|
102
111
|
input_embeds=input_embeds,
|
112
|
+
export_config=export_config,
|
103
113
|
)
|
104
114
|
|
105
115
|
|
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
|
|
190
190
|
"""
|
191
191
|
x = torch.permute(x, (0, 2, 3, 1))
|
192
192
|
|
193
|
-
# TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
|
194
|
-
# int32 when the bug is fixed.
|
195
193
|
builder = StableHLOCompositeBuilder(
|
196
194
|
name="odml.group_norm",
|
197
195
|
attr={
|
198
196
|
"num_groups": num_groups,
|
199
197
|
"epsilon": eps,
|
200
|
-
"reduction_axes": 3,
|
198
|
+
"reduction_axes": [3],
|
201
199
|
"channel_axis": 3,
|
202
200
|
},
|
203
201
|
)
|
@@ -29,7 +29,7 @@ class TransformersModelWrapper(verifier.ModelWrapper):
|
|
29
29
|
an object with `logits` field.
|
30
30
|
|
31
31
|
Transformers models get `max_new_tokens` settings for generate() via
|
32
|
-
|
32
|
+
GenerationConfig.
|
33
33
|
"""
|
34
34
|
|
35
35
|
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
@@ -38,5 +38,5 @@ class TransformersModelWrapper(verifier.ModelWrapper):
|
|
38
38
|
def generate(
|
39
39
|
self, inputs: torch.Tensor, max_new_tokens: int
|
40
40
|
) -> torch.IntTensor:
|
41
|
-
|
42
|
-
return self.model.generate(inputs=inputs, generation_config=
|
41
|
+
gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
|
42
|
+
return self.model.generate(inputs=inputs, generation_config=gen_config)
|
@@ -115,7 +115,7 @@ class ReauthoredModelWrapper(ModelWrapper):
|
|
115
115
|
# pixel_values only when it is not None. Otherwise, it may raise an error.
|
116
116
|
if pixel_values is None:
|
117
117
|
output = self.model.forward(
|
118
|
-
tokens, input_pos, kv_cache, self.export_config
|
118
|
+
tokens, input_pos, kv_cache, export_config=self.export_config
|
119
119
|
)
|
120
120
|
else:
|
121
121
|
output = self.model.forward(
|
@@ -16,12 +16,15 @@ import functools
|
|
16
16
|
import logging
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch import jax_bridge
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import context
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
21
|
+
import jax.numpy as jnp
|
22
|
+
from jax._src.lib.mlir import ir
|
19
23
|
import torch
|
20
24
|
import torch_xla2.ops.jaten # Import to load torch_xla2 ops
|
21
25
|
import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
|
22
26
|
|
23
|
-
|
24
|
-
|
27
|
+
LoweringContext = context.LoweringContext
|
25
28
|
|
26
29
|
@functools.cache
|
27
30
|
def _log_usage(op):
|
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
|
|
258
261
|
@lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
|
259
262
|
def _aten_copy(self, src, **kwargs):
|
260
263
|
return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
|
264
|
+
|
265
|
+
|
266
|
+
# Schema:
|
267
|
+
# - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
|
268
|
+
# -> Tensor
|
269
|
+
# Torch Reference:
|
270
|
+
# - https://pytorch.org/docs/stable/generated/torch.einsum.html
|
271
|
+
# - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
|
272
|
+
@registry.lower(torch.ops.aten.einsum.default)
|
273
|
+
def _aten_einsum_default(
|
274
|
+
lctx: LoweringContext,
|
275
|
+
equation: str,
|
276
|
+
tensors: list[ir.Value],
|
277
|
+
path=None,
|
278
|
+
):
|
279
|
+
_log_usage(torch.ops.aten.einsum.default)
|
280
|
+
|
281
|
+
@jax_bridge.wrap
|
282
|
+
def jax_lowering(operands):
|
283
|
+
# Ignore the input path and let JAX determine the path.
|
284
|
+
return jnp.einsum(equation, *operands, optimize="optimal")
|
285
|
+
|
286
|
+
return jax_lowering(lctx, tuple(tensors))
|
@@ -46,7 +46,13 @@ def decompositions():
|
|
46
46
|
|
47
47
|
torch._decomp.remove_decompositions(
|
48
48
|
decompositions,
|
49
|
-
[
|
49
|
+
[
|
50
|
+
torch.ops.aten.roll,
|
51
|
+
# Torch's default einsum impl/decompositions is less efficient and
|
52
|
+
# optimized through converter than JAX's impl. Disable einsum
|
53
|
+
# decomposition to use JAX bridge for a more efficient lowering.
|
54
|
+
torch.ops.aten.einsum.default,
|
55
|
+
],
|
50
56
|
)
|
51
57
|
|
52
58
|
# Override _safe_softmax decompositions with regular softmax.
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20241211
|
3
|
+
Version: 0.3.0.dev20241212
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -29,7 +29,7 @@ Requires-Dist: safetensors
|
|
29
29
|
Requires-Dist: tabulate
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
|
-
Requires-Dist: tf-nightly>=2.19.0.
|
32
|
+
Requires-Dist: tf-nightly>=2.19.0.dev20241201
|
33
33
|
Requires-Dist: ai-edge-litert-nightly
|
34
34
|
Requires-Dist: ai-edge-quantizer-nightly
|
35
35
|
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=QnJ2_alMOUe5ea0vTpY7AIBr8eoHvuwKaaj917g5DFA,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -64,9 +64,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
|
|
64
64
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
65
65
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
66
66
|
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
|
67
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
67
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
|
68
68
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
|
69
|
-
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
|
69
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
|
70
70
|
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
|
71
71
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
72
72
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
|
@@ -121,7 +121,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
|
|
121
121
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
122
122
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
|
123
123
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
124
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
124
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
|
125
125
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
126
126
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
127
127
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -151,8 +151,8 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7
|
|
151
151
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
152
152
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
153
153
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
154
|
-
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=
|
155
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
154
|
+
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
155
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=ESSA8W1EYNsd4ntwmXbr-dn-BcIS27hf53XL5RTwjEU,11941
|
156
156
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
157
157
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
158
158
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -186,11 +186,11 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T
|
|
186
186
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
|
187
187
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
188
188
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
189
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
189
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
190
190
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
191
191
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
|
192
192
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
193
|
-
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=
|
193
|
+
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
194
194
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
195
195
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
|
196
196
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
@@ -201,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
201
201
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
202
202
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
203
203
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
204
|
-
ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
205
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
-
ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
207
|
-
ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
208
|
-
ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/RECORD,,
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/METADATA,sha256=tkJXZvoB1p4WKAKgK9Ql071JxwI7BwU3gKmdJR5jcrs,1897
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
208
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/RECORD,,
|
File without changes
|
File without changes
|