ai-edge-torch-nightly 0.4.0.dev20250312__py3-none-any.whl → 0.4.0.dev20250314__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma2.py +3 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +25 -2
- ai_edge_torch/testing/__init__.py +4 -0
- ai_edge_torch/testing/export.py +83 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250312.dist-info → ai_edge_torch_nightly-0.4.0.dev20250314.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250312.dist-info → ai_edge_torch_nightly-0.4.0.dev20250314.dist-info}/RECORD +10 -9
- {ai_edge_torch_nightly-0.4.0.dev20250312.dist-info → ai_edge_torch_nightly-0.4.0.dev20250314.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250312.dist-info → ai_edge_torch_nightly-0.4.0.dev20250314.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250312.dist-info → ai_edge_torch_nightly-0.4.0.dev20250314.dist-info}/top_level.txt +0 -0
@@ -247,6 +247,9 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
247
247
|
rotary_base=10000,
|
248
248
|
rotary_percentage=1.0,
|
249
249
|
qkv_transpose_before_split=True,
|
250
|
+
# The safetensors from HF do not use the interleaved qkv format, so
|
251
|
+
# we need to disable interleaving here in the model config.
|
252
|
+
qkv_fused_interleaved=False,
|
250
253
|
logit_softcap=50.0,
|
251
254
|
sliding_window_size=4096,
|
252
255
|
attn_type=(
|
@@ -18,14 +18,16 @@ import functools
|
|
18
18
|
import numbers
|
19
19
|
from typing import Any
|
20
20
|
from typing import Optional
|
21
|
-
|
21
|
+
from ai_edge_torch.odml_torch import export_utils
|
22
22
|
from jax._src.lib.mlir import ir
|
23
23
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
24
24
|
import numpy as np
|
25
25
|
import torch
|
26
|
+
import torch.utils._pytree as pytree
|
26
27
|
|
27
28
|
|
28
|
-
def torch_dtype_to_ir_element_type(dtype):
|
29
|
+
def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
|
30
|
+
"""Builds ir.Type from torch dtype."""
|
29
31
|
ty_get = {
|
30
32
|
torch.double: ir.F64Type.get,
|
31
33
|
torch.float32: ir.F32Type.get,
|
@@ -39,6 +41,27 @@ def torch_dtype_to_ir_element_type(dtype):
|
|
39
41
|
return ty_get()
|
40
42
|
|
41
43
|
|
44
|
+
def node_meta_to_ir_types(node: torch.fx.Node) -> list[ir.Type]:
  """Builds IR result types from torch FX node meta.

  The metadata is read from ``node.meta["tensor_meta"]``, falling back to
  ``node.meta["val"]`` when the former is absent (None) or an empty container.
  The metadata may be a single tensor-like object or a (possibly nested)
  container of them. Every object exposing both ``dtype`` and ``shape`` becomes
  one ``ir.RankedTensorType``, in pytree flattening order.

  Args:
    node: The FX node whose result types should be built.

  Returns:
    One ranked tensor type per tensor-like leaf of the node metadata.
    Dimensions for which ``export_utils.is_torch_dynamic`` is true are emitted
    as ``export_utils.IR_DYNAMIC``; other dimensions are used as-is.

  Raises:
    RuntimeError: If the node carries neither ``tensor_meta`` nor ``val``.
  """

  def _is_missing(value: Any) -> bool:
    # Deliberately avoids `or` / `not value`: `val` may hold a tensor, and
    # bool() on a tensor raises for zero- or multi-element tensors (and for
    # fake tensors without concrete data). Only None and empty containers are
    # treated as "no metadata".
    if value is None:
      return True
    return isinstance(value, (list, tuple, dict)) and not value

  tensor_meta = node.meta.get("tensor_meta")
  if _is_missing(tensor_meta):
    tensor_meta = node.meta.get("val")
  if _is_missing(tensor_meta):
    raise RuntimeError(f"{node.name} does not have tensor meta")

  # Treat anything with both `dtype` and `shape` as a leaf so tensor-like
  # metadata objects are not themselves flattened by pytree.
  tensor_meta_list, _ = pytree.tree_flatten(
      [tensor_meta],
      is_leaf=lambda x: hasattr(x, "dtype") and hasattr(x, "shape"),
  )
  results = []
  for meta in tensor_meta_list:
    shape = [
        export_utils.IR_DYNAMIC if export_utils.is_torch_dynamic(dim) else dim
        for dim in meta.shape
    ]
    elty = torch_dtype_to_ir_element_type(meta.dtype)
    results.append(ir.RankedTensorType.get(shape, elty))
  return results
|
63
|
+
|
64
|
+
|
42
65
|
def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):
|
43
66
|
if isinstance(ty, ir.IntegerType):
|
44
67
|
if ty.width == 1:
|
@@ -12,3 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from ai_edge_torch.testing import export
|
16
|
+
from ai_edge_torch.testing import model_coverage
|
17
|
+
|
18
|
+
# Re-export the testing helper at package level so callers can use
# `ai_edge_torch.testing.export_with_tensor_inputs_only` directly.
export_with_tensor_inputs_only = export.export_with_tensor_inputs_only
|
@@ -0,0 +1,83 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Torch export utilities for testing."""
|
16
|
+
|
17
|
+
from collections.abc import Callable
|
18
|
+
from typing import Any
|
19
|
+
|
20
|
+
import torch
|
21
|
+
from torch.utils import _pytree as pytree
|
22
|
+
|
23
|
+
|
24
|
+
def export_with_tensor_inputs_only(
|
25
|
+
model: Callable[..., Any],
|
26
|
+
args: tuple[Any, ...],
|
27
|
+
kwargs: dict[str, Any],
|
28
|
+
) -> torch.export.ExportedProgram:
|
29
|
+
"""Exports a PyTorch model, treating only tensor inputs as export inputs.
|
30
|
+
|
31
|
+
This function takes a PyTorch model and its input arguments (positional and
|
32
|
+
keyword) and exports it using `torch.export.export`. However, it modifies
|
33
|
+
the export process such that only the `torch.Tensor` arguments in the
|
34
|
+
inputs are considered as export inputs to the exported graph. All other
|
35
|
+
argument types (e.g., scalars, lists, tuples containing non-tensors) are
|
36
|
+
treated as constants.
|
37
|
+
|
38
|
+
This is useful for testing scenarios where you want to export a model but
|
39
|
+
want to avoid issues that might arise from non-tensor inputs
|
40
|
+
being treated as variables, or when you specifically want to focus on the
|
41
|
+
graph structure based on tensor operations.
|
42
|
+
|
43
|
+
Args:
|
44
|
+
model: The PyTorch `nn.Module` to be exported.
|
45
|
+
args: A tuple of positional arguments to be passed to the model's `forward`
|
46
|
+
method.
|
47
|
+
kwargs: A dictionary of keyword arguments to be passed to the model's
|
48
|
+
`forward` method.
|
49
|
+
|
50
|
+
Returns:
|
51
|
+
torch.export.ExportedProgram: The exported program representing the model
|
52
|
+
computation with only tensor inputs being export inputs.
|
53
|
+
"""
|
54
|
+
flatten_args, treespec = pytree.tree_flatten([args, kwargs])
|
55
|
+
|
56
|
+
export_args = []
|
57
|
+
indices = []
|
58
|
+
for i, arg in enumerate(flatten_args):
|
59
|
+
if isinstance(arg, torch.Tensor):
|
60
|
+
export_args.append(arg)
|
61
|
+
indices.append(i)
|
62
|
+
|
63
|
+
class ModuleWrapper(torch.nn.Module):
|
64
|
+
|
65
|
+
def __init__(self, func, original_args, original_kwargs):
|
66
|
+
super().__init__()
|
67
|
+
self.original_args = list(flatten_args)
|
68
|
+
self.func = func
|
69
|
+
|
70
|
+
def forward(self, *export_args):
|
71
|
+
flatten_args = self.original_args.copy()
|
72
|
+
for i, arg in zip(indices, export_args):
|
73
|
+
flatten_args[i] = arg
|
74
|
+
args, kwargs = pytree.tree_unflatten(flatten_args, treespec)
|
75
|
+
return self.func(*args, **kwargs)
|
76
|
+
|
77
|
+
export_args = tuple(export_args)
|
78
|
+
export_kwargs = {}
|
79
|
+
return torch.export.export(
|
80
|
+
ModuleWrapper(model, args, kwargs).eval(),
|
81
|
+
export_args,
|
82
|
+
export_kwargs,
|
83
|
+
)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.dev20250312
|
3
|
+
Version: 0.4.0.dev20250314
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=PjlstuIJ-GlyKyFBMrwc7RQFRecNIkHpz5aIzvYNRKo,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
|
|
57
57
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
|
58
58
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
|
59
59
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
|
60
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=CK1lHw-YQPAr26KMdrYA6icQHvKH59yHAQ4eC4X636o,11539
|
61
61
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
62
62
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
63
63
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
@@ -233,17 +233,18 @@ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQD
|
|
233
233
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
234
234
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
235
235
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
236
|
-
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=
|
236
|
+
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=tIyZiSy2Rbtea9ZlRYUfTaYE5vW_lAU6itT6_rUp8Qg,7028
|
237
237
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
238
238
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
239
239
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
240
240
|
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
|
241
241
|
ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
|
242
|
-
ai_edge_torch/testing/__init__.py,sha256=
|
242
|
+
ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPGhPbo,833
|
243
|
+
ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
|
243
244
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
244
245
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
245
|
-
ai_edge_torch_nightly-0.4.0.
|
246
|
-
ai_edge_torch_nightly-0.4.0.
|
247
|
-
ai_edge_torch_nightly-0.4.0.
|
248
|
-
ai_edge_torch_nightly-0.4.0.
|
249
|
-
ai_edge_torch_nightly-0.4.0.
|
246
|
+
ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
247
|
+
ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/METADATA,sha256=n_c6T76WR-J-SCOmKKKzzuPoyM4i_2W2TO6ub8AuDw0,1966
|
248
|
+
ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
249
|
+
ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
250
|
+
ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/RECORD,,
|
File without changes
|
File without changes
|