ai-edge-torch-nightly 0.4.0.dev20250313__py3-none-any.whl → 0.4.0.dev20250314__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,14 +18,16 @@ import functools
18
18
  import numbers
19
19
  from typing import Any
20
20
  from typing import Optional
21
-
21
+ from ai_edge_torch.odml_torch import export_utils
22
22
  from jax._src.lib.mlir import ir
23
23
  from jax._src.lib.mlir.dialects import hlo as stablehlo
24
24
  import numpy as np
25
25
  import torch
26
+ import torch.utils._pytree as pytree
26
27
 
27
28
 
28
- def torch_dtype_to_ir_element_type(dtype):
29
+ def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
30
+ """Builds ir.Type from torch dtype."""
29
31
  ty_get = {
30
32
  torch.double: ir.F64Type.get,
31
33
  torch.float32: ir.F32Type.get,
@@ -39,6 +41,27 @@ def torch_dtype_to_ir_element_type(dtype):
39
41
  return ty_get()
40
42
 
41
43
 
44
+ def node_meta_to_ir_types(node: torch.fx.Node) -> list[ir.Type]:
45
+ """Builds IR result types from torch FX node meta."""
46
+ tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
47
+ if not tensor_meta:
48
+ raise RuntimeError(f"{node.name} does not have tensor meta")
49
+
50
+ tensor_meta_list, _ = pytree.tree_flatten(
51
+ [tensor_meta],
52
+ is_leaf=lambda x: hasattr(x, "dtype") and hasattr(x, "shape"),
53
+ )
54
+ results = []
55
+ for meta in tensor_meta_list:
56
+ shape = [
57
+ export_utils.IR_DYNAMIC if export_utils.is_torch_dynamic(dim) else dim
58
+ for dim in meta.shape
59
+ ]
60
+ elty = torch_dtype_to_ir_element_type(meta.dtype)
61
+ results.append(ir.RankedTensorType.get(shape, elty))
62
+ return results
63
+
64
+
42
65
  def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):
43
66
  if isinstance(ty, ir.IntegerType):
44
67
  if ty.width == 1:
@@ -12,3 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from ai_edge_torch.testing import export
16
+ from ai_edge_torch.testing import model_coverage
17
+
18
+ export_with_tensor_inputs_only = export.export_with_tensor_inputs_only
@@ -0,0 +1,83 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch export utilities for testing."""
16
+
17
+ from collections.abc import Callable
18
+ from typing import Any
19
+
20
+ import torch
21
+ from torch.utils import _pytree as pytree
22
+
23
+
24
+ def export_with_tensor_inputs_only(
25
+ model: Callable[..., Any],
26
+ args: tuple[Any, ...],
27
+ kwargs: dict[str, Any],
28
+ ) -> torch.export.ExportedProgram:
29
+ """Exports a PyTorch model, treating only tensor inputs as export inputs.
30
+
31
+ This function takes a PyTorch model and its input arguments (positional and
32
+ keyword) and exports it using `torch.export.export`. However, it modifies
33
+ the export process such that only the `torch.Tensor` arguments in the
34
+ inputs are considered as export inputs to the exported graph. All other
35
+ argument types (e.g., scalars, lists, tuples containing non-tensors) are
36
+ treated as constants.
37
+
38
+ This is useful for testing scenarios where you want to export a model but
39
+ want to avoid issues that might arise from non-tensor inputs
40
+ being treated as variables, or when you specifically want to focus on the
41
+ graph structure based on tensor operations.
42
+
43
+ Args:
44
+ model: The PyTorch `nn.Module` to be exported.
45
+ args: A tuple of positional arguments to be passed to the model's `forward`
46
+ method.
47
+ kwargs: A dictionary of keyword arguments to be passed to the model's
48
+ `forward` method.
49
+
50
+ Returns:
51
+ torch.export.ExportedProgram: The exported program representing the model
52
+ computation with only tensor inputs being export inputs.
53
+ """
54
+ flatten_args, treespec = pytree.tree_flatten([args, kwargs])
55
+
56
+ export_args = []
57
+ indices = []
58
+ for i, arg in enumerate(flatten_args):
59
+ if isinstance(arg, torch.Tensor):
60
+ export_args.append(arg)
61
+ indices.append(i)
62
+
63
+ class ModuleWrapper(torch.nn.Module):
64
+
65
+ def __init__(self, func, original_args, original_kwargs):
66
+ super().__init__()
67
+ self.original_args = list(flatten_args)
68
+ self.func = func
69
+
70
+ def forward(self, *export_args):
71
+ flatten_args = self.original_args.copy()
72
+ for i, arg in zip(indices, export_args):
73
+ flatten_args[i] = arg
74
+ args, kwargs = pytree.tree_unflatten(flatten_args, treespec)
75
+ return self.func(*args, **kwargs)
76
+
77
+ export_args = tuple(export_args)
78
+ export_kwargs = {}
79
+ return torch.export.export(
80
+ ModuleWrapper(model, args, kwargs).eval(),
81
+ export_args,
82
+ export_kwargs,
83
+ )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.4.0.dev20250313"
16
+ __version__ = "0.4.0.dev20250314"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.4.0.dev20250313
3
+ Version: 0.4.0.dev20250314
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=p8uPatCB4LYJx_OsddOLTOjd9FYSuB1_RgF44IzGGJ0,706
5
+ ai_edge_torch/version.py,sha256=PjlstuIJ-GlyKyFBMrwc7RQFRecNIkHpz5aIzvYNRKo,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -233,17 +233,18 @@ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQD
233
233
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
234
234
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
235
235
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
236
- ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
236
+ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=tIyZiSy2Rbtea9ZlRYUfTaYE5vW_lAU6itT6_rUp8Qg,7028
237
237
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
238
238
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
239
239
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
240
240
  ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
241
241
  ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
242
- ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
242
+ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPGhPbo,833
243
+ ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
243
244
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
244
245
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
245
- ai_edge_torch_nightly-0.4.0.dev20250313.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
246
- ai_edge_torch_nightly-0.4.0.dev20250313.dist-info/METADATA,sha256=fdcAPu4DlPvyxTrNjWoiqKQ5ZXcTjdHlrcngjnOLOV4,1966
247
- ai_edge_torch_nightly-0.4.0.dev20250313.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
248
- ai_edge_torch_nightly-0.4.0.dev20250313.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
249
- ai_edge_torch_nightly-0.4.0.dev20250313.dist-info/RECORD,,
246
+ ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
247
+ ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/METADATA,sha256=n_c6T76WR-J-SCOmKKKzzuPoyM4i_2W2TO6ub8AuDw0,1966
248
+ ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
249
+ ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
250
+ ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/RECORD,,