ai-edge-torch-nightly 0.3.0.dev20241211__py3-none-any.whl → 0.3.0.dev20241212__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/paligemma/decoder.py +3 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/utilities/transformers_verifier.py +3 -3
- ai_edge_torch/generative/utilities/verifier.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +7 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241211.dist-info → ai_edge_torch_nightly-0.3.0.dev20241212.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,8 @@
|
|
15
15
|
|
16
16
|
"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""
|
17
17
|
|
18
|
+
from typing import Optional
|
19
|
+
|
18
20
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
19
21
|
import ai_edge_torch.generative.layers.model_config as cfg
|
20
22
|
from ai_edge_torch.generative.utilities import model_builder
|
@@ -51,6 +53,7 @@ class Decoder(model_builder.DecoderOnlyModel):
|
|
51
53
|
input_pos: torch.Tensor,
|
52
54
|
kv_cache: kv_utils.KVCache,
|
53
55
|
input_embeds: torch.Tensor = None,
|
56
|
+
export_config: Optional[model_builder.ExportConfig] = None,
|
54
57
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
55
58
|
if input_embeds is None:
|
56
59
|
return super().forward(tokens, input_pos, kv_cache)
|
@@ -16,11 +16,13 @@
|
|
16
16
|
"""Example of building a full-stack of PaliGemma model."""
|
17
17
|
|
18
18
|
from dataclasses import dataclass
|
19
|
+
from typing import Optional
|
19
20
|
|
20
21
|
from ai_edge_torch.generative.examples.paligemma import decoder
|
21
22
|
from ai_edge_torch.generative.examples.paligemma import image_encoder
|
22
23
|
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
23
24
|
import ai_edge_torch.generative.layers.model_config as cfg
|
25
|
+
from ai_edge_torch.generative.utilities import model_builder
|
24
26
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
25
27
|
import torch
|
26
28
|
from torch import nn
|
@@ -67,9 +69,16 @@ class PaliGemma(nn.Module):
|
|
67
69
|
input_pos: torch.Tensor,
|
68
70
|
kv_cache: kv_utils.KVCache,
|
69
71
|
pixel_values: torch.Tensor = None,
|
72
|
+
export_config: Optional[model_builder.ExportConfig] = None,
|
70
73
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
71
74
|
if pixel_values is None:
|
72
|
-
return self.decoder(
|
75
|
+
return self.decoder(
|
76
|
+
tokens=tokens,
|
77
|
+
input_pos=input_pos,
|
78
|
+
kv_cache=kv_cache,
|
79
|
+
input_embeds=None,
|
80
|
+
export_config=export_config
|
81
|
+
)
|
73
82
|
|
74
83
|
input_embeds = self.decoder.tok_embedding(tokens)
|
75
84
|
|
@@ -100,6 +109,7 @@ class PaliGemma(nn.Module):
|
|
100
109
|
input_pos=input_pos,
|
101
110
|
kv_cache=kv_cache,
|
102
111
|
input_embeds=input_embeds,
|
112
|
+
export_config=export_config,
|
103
113
|
)
|
104
114
|
|
105
115
|
|
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
|
|
190
190
|
"""
|
191
191
|
x = torch.permute(x, (0, 2, 3, 1))
|
192
192
|
|
193
|
-
# TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
|
194
|
-
# int32 when the bug is fixed.
|
195
193
|
builder = StableHLOCompositeBuilder(
|
196
194
|
name="odml.group_norm",
|
197
195
|
attr={
|
198
196
|
"num_groups": num_groups,
|
199
197
|
"epsilon": eps,
|
200
|
-
"reduction_axes": 3,
|
198
|
+
"reduction_axes": [3],
|
201
199
|
"channel_axis": 3,
|
202
200
|
},
|
203
201
|
)
|
@@ -29,7 +29,7 @@ class TransformersModelWrapper(verifier.ModelWrapper):
|
|
29
29
|
an object with `logits` field.
|
30
30
|
|
31
31
|
Transformers models get `max_new_tokens` settings for generate() via
|
32
|
-
|
32
|
+
GenerationConfig.
|
33
33
|
"""
|
34
34
|
|
35
35
|
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
@@ -38,5 +38,5 @@ class TransformersModelWrapper(verifier.ModelWrapper):
|
|
38
38
|
def generate(
|
39
39
|
self, inputs: torch.Tensor, max_new_tokens: int
|
40
40
|
) -> torch.IntTensor:
|
41
|
-
|
42
|
-
return self.model.generate(inputs=inputs, generation_config=
|
41
|
+
gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
|
42
|
+
return self.model.generate(inputs=inputs, generation_config=gen_config)
|
@@ -115,7 +115,7 @@ class ReauthoredModelWrapper(ModelWrapper):
|
|
115
115
|
# pixel_values only when it is not None. Otherwise, it may raise an error.
|
116
116
|
if pixel_values is None:
|
117
117
|
output = self.model.forward(
|
118
|
-
tokens, input_pos, kv_cache, self.export_config
|
118
|
+
tokens, input_pos, kv_cache, export_config=self.export_config
|
119
119
|
)
|
120
120
|
else:
|
121
121
|
output = self.model.forward(
|
@@ -16,12 +16,15 @@ import functools
|
|
16
16
|
import logging
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch import jax_bridge
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import context
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
21
|
+
import jax.numpy as jnp
|
22
|
+
from jax._src.lib.mlir import ir
|
19
23
|
import torch
|
20
24
|
import torch_xla2.ops.jaten # Import to load torch_xla2 ops
|
21
25
|
import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
|
22
26
|
|
23
|
-
|
24
|
-
|
27
|
+
LoweringContext = context.LoweringContext
|
25
28
|
|
26
29
|
@functools.cache
|
27
30
|
def _log_usage(op):
|
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
|
|
258
261
|
@lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
|
259
262
|
def _aten_copy(self, src, **kwargs):
|
260
263
|
return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
|
264
|
+
|
265
|
+
|
266
|
+
# Schema:
|
267
|
+
# - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
|
268
|
+
# -> Tensor
|
269
|
+
# Torch Reference:
|
270
|
+
# - https://pytorch.org/docs/stable/generated/torch.einsum.html
|
271
|
+
# - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
|
272
|
+
@registry.lower(torch.ops.aten.einsum.default)
|
273
|
+
def _aten_einsum_default(
|
274
|
+
lctx: LoweringContext,
|
275
|
+
equation: str,
|
276
|
+
tensors: list[ir.Value],
|
277
|
+
path=None,
|
278
|
+
):
|
279
|
+
_log_usage(torch.ops.aten.einsum.default)
|
280
|
+
|
281
|
+
@jax_bridge.wrap
|
282
|
+
def jax_lowering(operands):
|
283
|
+
# Ignore the input path and let JAX determine the path.
|
284
|
+
return jnp.einsum(equation, *operands, optimize="optimal")
|
285
|
+
|
286
|
+
return jax_lowering(lctx, tuple(tensors))
|
@@ -46,7 +46,13 @@ def decompositions():
|
|
46
46
|
|
47
47
|
torch._decomp.remove_decompositions(
|
48
48
|
decompositions,
|
49
|
-
[
|
49
|
+
[
|
50
|
+
torch.ops.aten.roll,
|
51
|
+
# Torch's default einsum impl/decompositions is less efficient and
|
52
|
+
# optimized through converter than JAX's impl. Disable einsum
|
53
|
+
# decomposition to use JAX bridge for a more efficient lowering.
|
54
|
+
torch.ops.aten.einsum.default,
|
55
|
+
],
|
50
56
|
)
|
51
57
|
|
52
58
|
# Override _safe_softmax decompositions with regular softmax.
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241212
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -29,7 +29,7 @@ Requires-Dist: safetensors
|
|
29
29
|
Requires-Dist: tabulate
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
|
-
Requires-Dist: tf-nightly>=2.19.0.
|
32
|
+
Requires-Dist: tf-nightly>=2.19.0.dev20241201
|
33
33
|
Requires-Dist: ai-edge-litert-nightly
|
34
34
|
Requires-Dist: ai-edge-quantizer-nightly
|
35
35
|
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=QnJ2_alMOUe5ea0vTpY7AIBr8eoHvuwKaaj917g5DFA,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -64,9 +64,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
|
|
64
64
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
65
65
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
66
66
|
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
|
67
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
67
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
|
68
68
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
|
69
|
-
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
|
69
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
|
70
70
|
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
|
71
71
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
72
72
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
|
@@ -121,7 +121,7 @@ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbX
|
|
121
121
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
122
122
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
|
123
123
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
124
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
124
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
|
125
125
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
126
126
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
127
127
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -151,8 +151,8 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7
|
|
151
151
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
152
152
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
153
153
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
154
|
-
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=
|
155
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
154
|
+
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
155
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=ESSA8W1EYNsd4ntwmXbr-dn-BcIS27hf53XL5RTwjEU,11941
|
156
156
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
157
157
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
158
158
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -186,11 +186,11 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T
|
|
186
186
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
|
187
187
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
188
188
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
189
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
189
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
190
190
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
191
191
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
|
192
192
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
193
|
-
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=
|
193
|
+
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
194
194
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
195
195
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
|
196
196
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
@@ -201,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
201
201
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
202
202
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
203
203
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
204
|
-
ai_edge_torch_nightly-0.3.0.
|
205
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/METADATA,sha256=tkJXZvoB1p4WKAKgK9Ql071JxwI7BwU3gKmdJR5jcrs,1897
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
208
|
+
ai_edge_torch_nightly-0.3.0.dev20241212.dist-info/RECORD,,
|
File without changes
|
File without changes
|