ai-edge-torch-nightly 0.3.0.dev20250121__py3-none-any.whl → 0.3.0.dev20250123__py3-none-any.whl

Files changed (35)
  1. ai_edge_torch/__init__.py +1 -1
  2. ai_edge_torch/_convert/conversion.py +6 -10
  3. ai_edge_torch/_convert/fx_passes/__init__.py +1 -1
  4. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +6 -3
  5. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +9 -11
  6. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -3
  7. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -0
  8. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py +33 -0
  9. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -0
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -3
  11. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +4 -4
  12. ai_edge_torch/_convert/test/test_convert.py +2 -2
  13. ai_edge_torch/fx_infra/__init__.py +32 -0
  14. ai_edge_torch/fx_infra/_canonicalize_pass.py +27 -0
  15. ai_edge_torch/fx_infra/_safe_run_decompositions.py +57 -0
  16. ai_edge_torch/fx_infra/decomp.py +80 -0
  17. ai_edge_torch/fx_infra/graph_utils.py +42 -0
  18. ai_edge_torch/{fx_pass_base.py → fx_infra/pass_base.py} +0 -28
  19. ai_edge_torch/generative/fx_passes/__init__.py +3 -3
  20. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -3
  21. ai_edge_torch/hlfb/mark_pattern/__init__.py +0 -1
  22. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +5 -21
  23. ai_edge_torch/hlfb/mark_pattern/pattern.py +20 -11
  24. ai_edge_torch/hlfb/test/test_mark_pattern.py +18 -15
  25. ai_edge_torch/odml_torch/_torch_future.py +0 -27
  26. ai_edge_torch/odml_torch/export.py +6 -8
  27. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  28. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +65 -0
  29. ai_edge_torch/version.py +1 -1
  30. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/METADATA +1 -1
  31. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/RECORD +34 -28
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +0 -69
  33. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/LICENSE +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/WHEEL +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 from ai_edge_torch._config import config
 from ai_edge_torch._convert.converter import convert
 from ai_edge_torch._convert.converter import signature
@@ -20,6 +19,7 @@ from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
 from ai_edge_torch.model import Model
 from ai_edge_torch.version import __version__
 
+
 def load(path: str) -> Model:
   """Imports an ai_edge_torch model from disk.
 
ai_edge_torch/_convert/conversion.py CHANGED
@@ -17,7 +17,7 @@ import logging
 from typing import Any, Literal, Optional, Union
 
 import ai_edge_torch
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
 from ai_edge_torch._convert import fx_passes
@@ -53,7 +53,7 @@ def _run_convert_passes(
       fx_passes.CanonicalizePass(),
   ]
 
-  exported_program = fx_pass_base.run_passes(exported_program, passes)
+  exported_program = fx_infra.run_passes(exported_program, passes)
   return exported_program
 
 
@@ -125,14 +125,10 @@ def convert_signatures(
     else:
       exported_program = torch.export.export(**kwargs, strict=True)
 
-    if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
-      # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
-      # stop-gap table which replicates the old behaviour of post-dispatch IR.
-      # This could help ensure the collection of aten ops remaining still as the
-      # implementation of torch.export changes.
-      exported_program = exported_program.run_decompositions(
-          torch._decomp._decomp_table_to_post_autograd_aten()
-      )
+    exported_program = fx_infra.safe_run_decompositions(
+        exported_program,
+        fx_infra.decomp.pre_convert_decomp(),
+    )
     return exported_program
 
   exported_programs: torch.export.ExportedProgram = [
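The hunk above swaps the torch 2.5-gated `_decomp_table_to_post_autograd_aten()` call for the new `fx_infra` entry points. A minimal sketch of that flow invoked directly, assuming this nightly (ai-edge-torch-nightly 0.3.0.dev20250123) is installed; `TinyModel` is a hypothetical module:

import torch
from ai_edge_torch import fx_infra


class TinyModel(torch.nn.Module):

  def forward(self, x):
    return torch.nn.functional.relu(x) + 1


ep = torch.export.export(TinyModel().eval(), (torch.rand(4, 4),))
# Same call convert_signatures now makes right after torch.export.export:
# decompose to the curated pre-convert op set.
ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_convert_decomp())
print(ep.graph_module.code)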
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED
@@ -20,4 +20,4 @@ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import Bu
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
 from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
-from ai_edge_torch.fx_pass_base import CanonicalizePass
+from ai_edge_torch.fx_infra import CanonicalizePass
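For downstream code the rename is mechanical: `fx_pass_base` becomes `fx_infra`, with the same pass base classes re-exported. A sketch of a custom pass against the relocated API, assuming this nightly; `NoOpPass` is hypothetical:

import torch
from ai_edge_torch import fx_infra


class NoOpPass(fx_infra.PassBase):
  """Placeholder pass; inspect or rewrite graph_module.graph in call()."""

  def call(self, graph_module: torch.fx.GraphModule):
    # Mirror the structure the passes in this diff use: lint, recompile,
    # then wrap the module in a PassResult.
    graph_module.graph.lint()
    graph_module.recompile()
    return fx_infra.PassResult(graph_module, True)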
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py CHANGED
@@ -14,7 +14,7 @@
 # ==============================================================================
 
 from typing import Any, Callable
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 from ai_edge_torch import lowertools
 import torch
 import torch.utils._pytree as pytree
@@ -25,6 +25,9 @@ _composite_builders: dict[
 
 
 def _register_composite_builder(op):
+  # Remove op from pre_convert_decomp to keep this in the decomposed graph.
+  fx_infra.decomp.remove_pre_convert_decomp(op)
+
   def inner(func):
     if isinstance(op, torch._ops.OpOverloadPacket):
       for overload in op.overloads():
@@ -276,7 +279,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
   node.target = embedding
 
 
-class BuildAtenCompositePass(fx_pass_base.PassBase):
+class BuildAtenCompositePass(fx_infra.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -285,4 +288,4 @@ class BuildAtenCompositePass(fx_pass_base.PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return fx_pass_base.PassResult(graph_module, True)
+    return fx_infra.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -16,7 +16,7 @@
 
 import functools
 
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
@@ -41,7 +41,7 @@ def _get_upsample_bilinear2d_pattern():
           x, scale_factor=2, mode="bilinear", align_corners=False
       ),
       export_args=(torch.rand(1, 3, 100, 100),),
-      decomp_table=_INTERPOLATE_DECOMPOSITIONS,
+      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
   )
 
   @pattern.register_attr_builder
@@ -65,7 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
           x, scale_factor=2, mode="bilinear", align_corners=True
      ),
       export_args=(torch.rand(1, 3, 100, 100),),
-      decomp_table=_INTERPOLATE_DECOMPOSITIONS,
+      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
   )
 
   @pattern.register_attr_builder
@@ -89,7 +89,7 @@ def _get_interpolate_nearest2d_pattern():
           x, scale_factor=2, mode="nearest"
       ),
       export_args=(torch.rand(1, 3, 100, 100),),
-      decomp_table=_INTERPOLATE_DECOMPOSITIONS,
+      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
   )
 
   @pattern.register_attr_builder
@@ -104,7 +104,7 @@ def _get_interpolate_nearest2d_pattern():
   return pattern
 
 
-class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
+class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
 
   def __init__(self):
     super().__init__()
@@ -115,11 +115,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
     ]
 
   def call(self, exported_program: torch.export.ExportedProgram):
-    exported_program = fx_pass_base.run_passes(
-        exported_program, [fx_pass_base.CanonicalizePass()]
-    )
-    exported_program = exported_program.run_decompositions(
-        _INTERPOLATE_DECOMPOSITIONS
+    exported_program = fx_infra.safe_run_decompositions(
+        exported_program,
+        _INTERPOLATE_DECOMPOSITIONS,
     )
 
     graph_module = exported_program.graph_module
@@ -128,4 +126,4 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
+    return fx_infra.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 from ai_edge_torch import lowertools
 import torch
 import torch.utils._pytree as pytree
@@ -61,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
   node.target = debuginfo_writer
 
 
-class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
+class InjectMlirDebuginfoPass(fx_infra.PassBase):
   """DEPRECATED: Debuginfo is injected automatically by odml_torch."""
 
   def call(self, graph_module: torch.fx.GraphModule):
@@ -70,4 +70,4 @@ class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return fx_pass_base.PassResult(graph_module, True)
+    return fx_infra.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py CHANGED
@@ -13,4 +13,5 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import _decomp_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass  # NOQA
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py ADDED
@@ -0,0 +1,33 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Remove decompositions for ops to keep in layout optimization."""
+from ai_edge_torch import fx_infra
+import torch
+
+__all__ = []
+
+aten = torch.ops.aten
+
+_OPS_TO_KEEP = [
+    aten.conv2d,
+    aten.max_pool2d,
+    aten._softmax.default,
+    aten.group_norm.default,
+    aten.native_group_norm.default,
+    aten.reflection_pad2d.default,
+]
+
+for op in _OPS_TO_KEEP:
+  fx_infra.decomp.remove_pre_convert_decomp(op)
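A hedged sketch of what this keep-list buys, assuming this nightly; `ConvModel` is hypothetical and the printed check is illustrative rather than a guaranteed invariant:

import torch
from ai_edge_torch import fx_infra
# Importing the fx_passes package runs _decomp_registry and prunes the table.
from ai_edge_torch._convert import fx_passes  # noqa: F401


class ConvModel(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.conv = torch.nn.Conv2d(3, 8, 3)

  def forward(self, x):
    return self.conv(x)


ep = torch.export.export(ConvModel().eval(), (torch.rand(1, 3, 16, 16),))
ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_convert_decomp())
# aten.conv2d should remain visible to the layout pass instead of being
# decomposed into aten.convolution.
print([str(n.target) for n in ep.graph.nodes if "conv" in str(n.target)])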
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py CHANGED
@@ -20,6 +20,7 @@ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import uti
 class OpFuncRegistry(dict):
 
   def register(self, op):
+
     ops = utils.flatten_torch_op_overloads(op)
 
     def inner(func):
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py CHANGED
@@ -18,7 +18,7 @@ import operator
 import os
 from typing import Union
 
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
@@ -30,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
 
 
-class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
+class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
 
   def get_source_meta(self, node: torch.fx.Node):
     keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -300,4 +300,4 @@ class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
     # Mark const node again for debugging
     self.mark_const_nodes(exported_program)
 
-    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
+    return fx_infra.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py CHANGED
@@ -15,11 +15,11 @@
 """Pass to remove all non user outputs from exported program."""
 
 
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 import torch
 
 
-class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
+class RemoveNonUserOutputsPass(fx_infra.ExportedProgramPassBase):
   """This pass removes all non user outputs from the exported program's output.
 
   The FX graph may output more tensors/data than what user's original model
@@ -47,6 +47,6 @@ class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
     node.args = (tuple(new_outputs),)
     exported_program.graph_signature.output_specs = new_output_specs
 
-    exported_program.graph_module.graph.lint()
+    exported_program.graph.eliminate_dead_code()
     exported_program.graph_module.recompile()
-    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
+    return fx_infra.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -511,7 +511,7 @@ class TestConvert(googletest.TestCase):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
     m = torchvision.models.resnet18().eval()
-    m = torch._export.capture_pre_autograd_graph(m, args)
+    m = torch.export.export_for_training(m, args).module()
 
     # Step 2: Insert observers or fake quantize modules
     quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
@@ -533,7 +533,7 @@ class TestConvert(googletest.TestCase):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
     m = torchvision.models.resnet18().eval()
-    m = torch._export.capture_pre_autograd_graph(m, args)
+    m = torch.export.export_for_training(m, args).module()
 
     # Step 2: Insert observers or fake quantize modules
     quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
ai_edge_torch/fx_infra/__init__.py ADDED
@@ -0,0 +1,32 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch.fx_infra import _canonicalize_pass
+from ai_edge_torch.fx_infra import _safe_run_decompositions
+from ai_edge_torch.fx_infra import decomp
+from ai_edge_torch.fx_infra import graph_utils
+from ai_edge_torch.fx_infra import pass_base
+
+
+PassBase = pass_base.PassBase
+PassResult = pass_base.PassResult
+FxPassBase = pass_base.FxPassBase
+FxPassResult = pass_base.FxPassResult
+ExportedProgramPassBase = pass_base.ExportedProgramPassBase
+ExportedProgramPassResult = pass_base.ExportedProgramPassResult
+run_passes = pass_base.run_passes
+
+CanonicalizePass = _canonicalize_pass.CanonicalizePass
+safe_run_decompositions = _safe_run_decompositions.safe_run_decompositions
ai_edge_torch/fx_infra/_canonicalize_pass.py ADDED
@@ -0,0 +1,27 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from ai_edge_torch.fx_infra import _safe_run_decompositions
+from ai_edge_torch.fx_infra import pass_base
+import torch
+
+
+class CanonicalizePass(pass_base.ExportedProgramPassBase):
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = _safe_run_decompositions.safe_run_decompositions(
+        exported_program, {}
+    )
+
+    return pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/fx_infra/_safe_run_decompositions.py ADDED
@@ -0,0 +1,57 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
+import torch
+
+
+# A dummy decomp table for running ExportedProgram.run_decompositions without
+# any op decompositions but just aot_export_module. Due to the check in
+# run_decompositions, if None or an empty dict is passed as decomp_table,
+# it will run the default aten-coreaten decompositions. Therefore a non-empty
+# dummy decomp table is needed.
+# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
+_DUMMY_DECOMP_TABLE = {
+    torch._ops.OperatorBase(): lambda: None,
+}
+
+
+def safe_run_decompositions(exported_program, decomp_table=None):
+  """Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
+
+  if decomp_table is not None and not decomp_table:
+    # Empty decomp table means no op decompositions. Use dummy decomp table
+    # instead for backward compatibility.
+    decomp_table = _DUMMY_DECOMP_TABLE
+
+  for node in exported_program.graph.nodes:
+    if node.target == torch.ops.aten.view.default:
+      # Passes or torch.export may generate aten.view nodes not respecting the
+      # tensor memory format. Changes all the aten.view to torch.reshape
+      # for retracing. If the input memory format is already contiguous,
+      # retracing in run_decomposition below would decompose torch.reshape
+      # back to one aten.view.
+      node.target = lambda self, size: torch.reshape(self.contiguous(), size)
+
+  exported_program = exported_program.run_decompositions(decomp_table)
+
+  if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
+    for node in exported_program.graph.nodes:
+      if node.target == torch.ops.aten._assert_tensor_metadata.default:
+        exported_program.graph.erase_node(node)
+
+  exported_program.graph.eliminate_dead_code()
+  exported_program.graph_module.recompile()
+
+  return exported_program
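Passing an empty dict now means "retrace only": the dummy-table shim above keeps run_decompositions from falling back to the default core-aten table. A hedged usage sketch, assuming this nightly; `ViewModel` is hypothetical:

import torch
from ai_edge_torch import fx_infra


class ViewModel(torch.nn.Module):

  def forward(self, x):
    return x.view(2, 8) + 1


ep = torch.export.export(ViewModel().eval(), (torch.rand(4, 4),))
# Retrace/canonicalize without decomposing any ops, as CanonicalizePass does.
ep = fx_infra.safe_run_decompositions(ep, {})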
ai_edge_torch/fx_infra/decomp.py ADDED
@@ -0,0 +1,80 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
+import torch
+
+# Decompositions to be run after torch.export before conversion and any
+# passes. Remove ops from this decomposition table if they need to be preserved
+# in passes.
+_pre_convert_decomp = torch._decomp.core_aten_decompositions().copy()
+
+
+# Decompositions to be run after conversion before odml_torch passes and
+# lowerings.
+_pre_lower_decomp = torch._decomp.core_aten_decompositions().copy()
+
+
+def _get_ops(op):
+  if isinstance(op, torch._ops.OpOverloadPacket):
+    return [getattr(op, overload) for overload in op.overloads()]
+  else:
+    return [op]
+
+
+def pre_convert_decomp():
+  """Decompositions to be run after torch.export before conversion and any passes."""
+  return _pre_convert_decomp.copy()
+
+
+def pre_lower_decomp():
+  """Decompositions to be run after conversion before odml_torch passes and lowerings."""
+  return _pre_lower_decomp.copy()
+
+
+def remove_pre_lower_decomp(op):
+  # Also remove from pre_convert_decomp which always run before pre_lower_
+  # decomp.
+  remove_pre_convert_decomp(op)
+
+  for op_ in _get_ops(op):
+    _pre_lower_decomp.pop(op_, None)
+
+
+def remove_pre_convert_decomp(op):
+  for op_ in _get_ops(op):
+    _pre_convert_decomp.pop(op_, None)
+
+
+def add_pre_convert_decomp(op, decomp):
+  # Also add decomp to pre_lower_decomp which runs after pre_convert_decomp.
+  add_pre_lower_decomp(op, decomp)
+
+  for op_ in _get_ops(op):
+    _pre_convert_decomp[op_] = decomp
+
+
+def add_pre_lower_decomp(op, decomp):
+  for op_ in _get_ops(op):
+    _pre_lower_decomp[op_] = decomp
+
+
+def update_pre_convert_decomp(decomps):
+  for op, decomp in decomps.items():
+    add_pre_convert_decomp(op, decomp)
+
+
+def update_pre_lower_decomp(decomps):
+  for op, decomp in decomps.items():
+    add_pre_lower_decomp(op, decomp)
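The two module-level tables give passes a single place to keep or replace decompositions at each stage. A sketch of how downstream code might use this API (the op choice and `my_rsqrt_decomp` below are hypothetical; the table semantics are from this file):

import torch
from ai_edge_torch import fx_infra


def my_rsqrt_decomp(x):
  # Hypothetical replacement rule, purely illustrative.
  return 1.0 / torch.sqrt(x)


# Keep aten.rsqrt intact through the pre-convert stage, e.g. so a pass can
# pattern-match it...
fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.rsqrt)
# ...then decompose it with a custom rule right before lowering.
fx_infra.decomp.add_pre_lower_decomp(
    torch.ops.aten.rsqrt.default, my_rsqrt_decomp
)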
ai_edge_torch/fx_infra/graph_utils.py ADDED
@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FX graph utilities."""
+import torch
+
+
+def remove_dangling_args(graph_module: torch.fx.GraphModule):
+  """Removes dangling args from the graph."""
+  for node in graph_module.graph.nodes:
+    if node.op == "placeholder" and not node.users:
+      graph_module.graph.erase_node(node)
+
+  graph_module.graph.lint()
+  graph_module.recompile()
+  return graph_module
+
+
+def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
+  """Removes aten._assert_tensor_metadata nodes from the graph.
+
+  This op is inserted by torch.export to check tensor metadata on custom ops. It
+  can break pattern matching and lowering.
+  """
+  for node in graph_module.graph.nodes:
+    if node.target == torch.ops.aten._assert_tensor_metadata.default:
+      graph_module.graph.erase_node(node)
+
+  graph_module.graph.lint()
+  graph_module.recompile()
+  return graph_module
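These helpers are graph-module-level twins of the cleanup safe_run_decompositions performs on a whole ExportedProgram. A hedged usage sketch, assuming this nightly; `TwoInOneOut` is hypothetical:

import torch
from ai_edge_torch.fx_infra import graph_utils


class TwoInOneOut(torch.nn.Module):

  def forward(self, x, unused):
    return x * 2


gm = torch.export.export(
    TwoInOneOut().eval(), (torch.rand(2), torch.rand(2))
).module()
# Drop the user-less "unused" placeholder and any
# aten._assert_tensor_metadata guard nodes before pattern matching.
gm = graph_utils.remove_dangling_args(gm)
gm = graph_utils.remove_assert_tensor_metadata_nodes(gm)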
ai_edge_torch/{fx_pass_base.py → fx_infra/pass_base.py} RENAMED
@@ -80,31 +80,3 @@ def run_passes(
         constants=exported_program.constants,
     )
   return exported_program
-
-
-class CanonicalizePass(ExportedProgramPassBase):
-
-  # A dummy decomp table for running ExportedProgram.run_decompositions without
-  # any op decompositions but just aot_export_module. Due to the check in
-  # run_decompositions, if None or an empty dict is passed as decomp_table,
-  # it will run the default aten-coreaten decompositions. Therefore a non-empty
-  # dummy decomp table is needed.
-  # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
-  _DUMMY_DECOMP_TABLE = {
-      torch._ops.OperatorBase(): lambda: None,
-  }
-
-  def call(self, exported_program: torch.export.ExportedProgram):
-    for node in exported_program.graph.nodes:
-      if node.target == torch.ops.aten.view.default:
-        # Passes or torch.export may generate aten.view nodes not respecting the
-        # tensor memory format. Changes all the aten.view to torch.reshape
-        # for retracing. If the input memory format is already contiguous,
-        # retracing in run_decomposition below would decompose torch.reshape
-        # back to one aten.view.
-        node.target = lambda self, size: torch.reshape(self.contiguous(), size)
-
-    exported_program = exported_program.run_decompositions(
-        self._DUMMY_DECOMP_TABLE
-    )
-    return ExportedProgramPassResult(exported_program, True)
ai_edge_torch/generative/fx_passes/__init__.py CHANGED
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch import fx_pass_base
-from ai_edge_torch.fx_pass_base import CanonicalizePass
+from ai_edge_torch import fx_infra
+from ai_edge_torch.fx_infra import CanonicalizePass
 from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
 import torch
 
@@ -21,7 +21,7 @@ import torch
 def run_generative_passes(
     exported_program: torch.export.ExportedProgram,
 ) -> torch.export.ExportedProgram:
-  return fx_pass_base.run_passes(
+  return fx_infra.run_passes(
       exported_program,
       [
          RemoveSDPACompositeZeroMaskPass(),
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py CHANGED
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 from ai_edge_torch import lowertools
 import torch
 
 
-class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
+class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
 
   def is_zero_tensor_node(self, node: torch.fx.Node):
     return node.target == torch.ops.aten.zeros.default
@@ -47,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
 
     exported_program.graph_module.graph.lint()
     exported_program.graph_module.recompile()
-    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
+    return fx_infra.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -90,7 +90,6 @@ def mark_pattern(
   graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
   match_with_attrs = pattern.match(graph_module_to_match)
-
   for match, attr in match_with_attrs:
     match_id = _get_uuid()
 
ai_edge_torch/hlfb/mark_pattern/fx_utils.py CHANGED
@@ -14,8 +14,13 @@
 # ==============================================================================
 """FX graph utilities for pattern matching clean ups."""
 
+from ai_edge_torch import fx_infra
 import torch
 
+remove_dangling_args = fx_infra.graph_utils.remove_dangling_args
+remove_assert_tensor_metadata_nodes = (
+    fx_infra.graph_utils.remove_assert_tensor_metadata_nodes
+)
 
 def is_clone_op(node: torch.fx.Node) -> bool:
   """Checks if the node is a clone op."""
@@ -46,24 +51,3 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
   gm.graph.lint()
   gm.recompile()
   return gm
-
-
-def remove_dangling_args(gm: torch.fx.GraphModule):
-  """Removes dangling args from the graph.
-
-  Args:
-    gm: The graph module to remove dangling args from.
-
-  Returns:
-    The graph module with dangling args removed.
-  """
-  nodes_to_erase = []
-  for node in gm.graph.nodes:
-    if node.op == "placeholder" and len(node.users) == 0:
-      nodes_to_erase.append(node)
-  for node in nodes_to_erase:
-    gm.graph.erase_node(node)
-
-  gm.graph.lint()
-  gm.recompile()
-  return gm
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -17,7 +17,7 @@
 import dataclasses
 from typing import Any, Callable, Optional, Union
 
-from ai_edge_torch import fx_pass_base
+from ai_edge_torch import fx_infra
 from ai_edge_torch.hlfb.mark_pattern import fx_utils
 import torch
 
@@ -118,7 +118,7 @@ def _find_scalar_attr(
   track_args[tracker.pattern_arg_pos] = source
   ep = torch.export.export(pattern_module, tuple(track_args))
   if decomp_table is not None:
-    ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
+    ep = fx_infra.run_passes(ep, [fx_infra.CanonicalizePass()])
    ep = ep.run_decompositions(decomp_table)
 
  scalar_locs = set()
@@ -152,13 +152,15 @@ class Pattern:
       self,
       name: str,
       module: Union[Callable, torch.nn.Module],
-      export_args: tuple[Any],
+      export_args: tuple[Any, ...],
       *,
       attr_builder: Callable[
          ["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
      ] = None,
       scalar_attr_trackers: list[ScalarAttrTracker] = None,
-      decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
+      extra_decomp_table: Optional[
+          dict[torch._ops.OperatorBase, Callable]
+      ] = None,
   ):
@@ -177,8 +179,9 @@ class Pattern:
       scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
         args in `export_args`, which are used to track the attr occurrence(s)
         and retrieve their values from the matched subgraph.
-      decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]): The
-        decomposition table to be run on the pattern's exported program.
+      extra_decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
+        Extra decomposition to be run on the pattern's exported program (in
+        addition to the default pre_convert_decomp in fx_infra.decomp).
     """
     if not isinstance(module, torch.nn.Module):
 
@@ -200,11 +203,14 @@ class Pattern:
     )
 
     exported_program = torch.export.export(module, export_args)
-    if decomp_table is not None:
-      exported_program = fx_pass_base.run_passes(
-          exported_program, [fx_pass_base.CanonicalizePass()]
-      )
-      exported_program = exported_program.run_decompositions(decomp_table)
+
+    decomp_table = fx_infra.decomp.pre_convert_decomp()
+    if extra_decomp_table is not None:
+      decomp_table.update(extra_decomp_table)
+
+    exported_program = fx_infra.safe_run_decompositions(
+        exported_program, decomp_table
+    )
 
     self.exported_program = exported_program
     self.graph_module = self.exported_program.graph_module
@@ -222,6 +228,9 @@ class Pattern:
     # sanitization.
     self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
     self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
+    self.graph_module = fx_utils.remove_assert_tensor_metadata_nodes(
+        self.graph_module
+    )
 
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}
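With this change a Pattern always decomposes through the shared pre-convert table, and `extra_decomp_table` entries are merged on top. A sketch of the renamed kwarg in action, mirroring build_interpolate_composite_pass.py above; the pattern name is hypothetical:

import torch
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module

_DECOMPS = torch._decomp.get_decompositions(
    [torch.ops.aten.upsample_nearest2d]
)

pattern = pattern_module.Pattern(
    "test.upsample_nearest2d",
    lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
    export_args=(torch.rand(1, 3, 10, 10),),
    extra_decomp_table=_DECOMPS,  # was: decomp_table=...
)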
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Tests for mark_pattern."""
 
+from ai_edge_torch import fx_infra
 from ai_edge_torch import lowertools
 from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
@@ -22,11 +23,13 @@ import torch
 from absl.testing import absltest as googletest
 
 
-def _export_stablehlo_mlir(model, args=None):
-  if not isinstance(model, torch.export.ExportedProgram):
-    ep = torch.export.export(model, args)
-  else:
-    ep = model
+def _export_and_decomp(mod, args):
+  ep = torch.export.export(mod, args)
+  ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_lower_decomp())
+  return ep
+
+
+def _to_mlir(ep: torch.export.ExportedProgram):
   return lowertools.exported_program_to_mlir_text(ep)
 
 
@@ -47,9 +50,9 @@ class TestMarkPattern(googletest.TestCase):
 
     model = TestModel().eval()
     args = (torch.rand(20, 20),)
-    exported_program = torch.export.export(model, args)
+    exported_program = _export_and_decomp(model, args)
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
-    mlir = _export_stablehlo_mlir(exported_program)
+    mlir = _to_mlir(exported_program)
 
     lowertools.assert_string_count(
         self,
@@ -73,9 +76,9 @@ class TestMarkPattern(googletest.TestCase):
 
     model = TestModel().eval()
     args = (torch.rand(20, 20),)
-    exported_program = torch.export.export(model, args)
+    exported_program = _export_and_decomp(model, args)
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
-    mlir = _export_stablehlo_mlir(exported_program)
+    mlir = _to_mlir(exported_program)
 
     lowertools.assert_string_count(
         self,
@@ -99,9 +102,9 @@ class TestMarkPattern(googletest.TestCase):
 
     model = TestModel().eval()
     args = (torch.rand(20, 20),)
-    exported_program = torch.export.export(model, args)
+    exported_program = _export_and_decomp(model, args)
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
-    mlir = _export_stablehlo_mlir(exported_program)
+    mlir = _to_mlir(exported_program)
 
     lowertools.assert_string_count(
         self,
@@ -137,9 +140,9 @@ class TestMarkPattern(googletest.TestCase):
 
     model = TestModel().eval()
     args = (torch.rand(10, 10),)
-    exported_program = torch.export.export(model, args)
+    exported_program = _export_and_decomp(model, args)
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
-    mlir = _export_stablehlo_mlir(exported_program)
+    mlir = _to_mlir(exported_program)
 
     lowertools.assert_string_count(
         self,
@@ -169,9 +172,9 @@ class TestMarkPattern(googletest.TestCase):
 
     model = TestModel().eval()
     args = (torch.rand(20, 20), torch.rand(20, 20))
-    exported_program = torch.export.export(model, args)
+    exported_program = _export_and_decomp(model, args)
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
-    mlir = _export_stablehlo_mlir(exported_program)
+    mlir = _to_mlir(exported_program)
 
     lowertools.assert_string_count(
         self,
ai_edge_torch/odml_torch/_torch_future.py CHANGED
@@ -59,30 +59,3 @@ def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
     ordered_tensor_constants = tuple()
 
   return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
-
-
-# TODO(b/331481564): Replace this with CanonicalizePass + run_decomposition
-def safe_run_decompositions(exported_program, decomp_table=None):
-  for node in exported_program.graph.nodes:
-    if node.target == torch.ops.aten.view.default:
-      # Passes or torch.export may generate aten.view nodes not respecting the
-      # tensor memory format. Changes all the aten.view to torch.reshape
-      # for retracing. If the input memory format is already contiguous,
-      # retracing in run_decomposition below would decompose torch.reshape
-      # back to one aten.view.
-      node.target = lambda self, size: torch.reshape(self.contiguous(), size)
-
-  return exported_program.run_decompositions(decomp_table)
-
-
-def dummy_decomp_table():
-  """Build dummy decomp table for run_decompositions without any decompositions.
-
-  Compatible for torch<=2.5.
-
-  Returns:
-    Decomp table for ExportedProgram.run_decompositions.
-  """
-  return {
-      torch._ops.OperatorBase(): lambda: None,
-  }
ai_edge_torch/odml_torch/export.py CHANGED
@@ -20,6 +20,7 @@ import io
 import operator
 from typing import Any, Callable, Optional
 
+from ai_edge_torch import fx_infra
 from jax.lib import xla_extension
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
@@ -302,16 +303,13 @@ def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
   """Lower the exported program to MLIR."""
-  exported_program = _torch_future.safe_run_decompositions(
-      exported_program, lowerings.decompositions()
+  exported_program = fx_infra.safe_run_decompositions(
+      exported_program,
+      fx_infra.decomp.pre_lower_decomp(),
   )
-
   _convert_i64_to_i32(exported_program)
-
-  # No decompositions but just retracing/cananicalization.
-  exported_program = _torch_future.safe_run_decompositions(
-      exported_program, _torch_future.dummy_decomp_table()
-  )
+  # Run decompositions for retracing and canonicalization.
+  exported_program = fx_infra.safe_run_decompositions(exported_program, {})
 
   # Passes below mutate the exported program to a state not executable by torch.
   # Do not call run_decompositions after applying the passes.
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -15,6 +15,7 @@
 from . import _basic
 from . import _batch_norm
 from . import _convolution
+from . import _decomp_registry
 from . import _jax_lowerings
 from . import _layer_norm
 from . import _quantized_decomposed
@@ -22,6 +23,5 @@ from . import _rand
 from . import context
 from . import registry
 from . import utils
-from .decomp import decompositions
 from .registry import lookup
 from .registry import lower
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py ADDED
@@ -0,0 +1,65 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Torch export decompositions to run before lowering."""
+
+from ai_edge_torch import fx_infra
+import torch
+
+
+fx_infra.decomp.update_pre_lower_decomp(
+    torch._decomp.get_decompositions([
+        torch.ops.aten.upsample_nearest2d,
+        torch.ops.aten._native_batch_norm_legit.no_stats,
+        torch.ops.aten._native_batch_norm_legit_functional,
+        torch.ops.aten._adaptive_avg_pool2d,
+        torch.ops.aten._adaptive_avg_pool3d,
+        torch.ops.aten.grid_sampler_2d,
+        torch.ops.aten.native_group_norm,
+        torch.ops.aten.native_dropout,
+        torch.ops.aten.reflection_pad1d,
+        torch.ops.aten.reflection_pad2d,
+        torch.ops.aten.reflection_pad3d,
+        torch.ops.aten.replication_pad1d,
+        torch.ops.aten.replication_pad2d,
+        torch.ops.aten.replication_pad3d,
+        torch.ops.aten.addmm,
+    ])
+)
+
+fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
+
+# Torch's default einsum impl/decompositions is less efficient and
+# optimized through converter than JAX's impl. Disable einsum
+# decomposition to use JAX bridge for a more efficient lowering.
+fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.einsum.default)
+
+
+# Override noop aten op decompositions for faster run_decompositions.
+fx_infra.decomp.add_pre_convert_decomp(
+    torch.ops.aten.alias.default, lambda x: x
+)
+fx_infra.decomp.add_pre_convert_decomp(
+    torch.ops.aten.detach.default, lambda x: x
+)
+
+# Override _safe_softmax decompositions with regular softmax.
+# _safe_softmax introduces additional check-select ops to guard extreme
+# input values to softmax, which could make the converted model inefficient
+# on-device.
+if hasattr(torch.ops.aten, "_safe_softmax"):
+  fx_infra.decomp.add_pre_convert_decomp(
+      torch.ops.aten._safe_softmax.default,
+      torch.softmax,
+  )
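A hedged sketch of the _safe_softmax override in use. It assumes a torch version where aten._safe_softmax exists (2.5+) and this nightly; `SDPA` is a hypothetical wrapper, and whether export emits _safe_softmax depends on the torch version and decomposition path:

import torch
from ai_edge_torch import fx_infra
from ai_edge_torch.odml_torch import lowerings  # noqa: F401  (runs _decomp_registry)


class SDPA(torch.nn.Module):

  def forward(self, q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


qkv = tuple(torch.rand(1, 4, 8, 8) for _ in range(3))
ep = torch.export.export(SDPA().eval(), qkv)
# If the SDPA decomposition emits aten._safe_softmax, the pre-lower table
# (which inherits pre-convert additions) maps it back to plain softmax,
# avoiding the extra check-and-select guards in the converted graph.
ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_lower_decomp())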
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250121"
+__version__ = "0.3.0.dev20250123"
{ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250121
+Version: 0.3.0.dev20250123
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/RECORD RENAMED
@@ -1,32 +1,32 @@
-ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
+ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
 ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
-ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=Zlgc-dAVs0gcH1g86CJVS-yy-T9eJBuCXDFHRUscOVA,706
+ai_edge_torch/version.py,sha256=szrxg2aB7mcm59IL_QVIqapmbw9Nz8AQ28vc9684bqY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=pSDY0CzZQP_jAMjSfQ1O7Ud_AF5ZDeDF-nE3nAu_hoo,5815
+ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
-ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
-ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
-ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
-ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
+ai_edge_torch/_convert/fx_passes/__init__.py,sha256=dG4WIICk0FqCH9euvbYHHsybRN7B1cYcuxN_OYxmjWo,1263
+ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
+ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
+ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=lgoH32l6zAbWTCpa_4-RWkHjqbNaPsBnhSObLIX8dL4,10551
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=gK9QJuLbpjXt0l6tVnzl9Miq6GLkJR-hB67i3VE13Og,17224
+ai_edge_torch/_convert/test/test_convert.py,sha256=o6tuJkD-ESaQxLxJpN104qpchm3LCtPmHinzQxe6PSg,17226
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -37,6 +37,12 @@ ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
 ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
 ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
 ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpxXtuiI,1347
+ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
+ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=ZbWheeZ8ydsxCk2aVGUgUynrkEkBOMjBCzPhS5uq4sU,2595
+ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
+ai_edge_torch/fx_infra/graph_utils.py,sha256=3UZAOHWOUh2LCj1E2_AKQn3gRDILi9JCdqSScjyOd4M,1535
+ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -115,8 +121,8 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
 ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
-ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
-ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
+ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
+ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=n5TbXdhBZi8jQe4j7-rox_MugMVvW8ReOhkTA3pfQkw,1919
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
@@ -159,11 +165,11 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yG
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
 ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
-ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=-BYE7MGMxr-VfBy8tAiiOaCqYv8ytJ0w5l2P8B7h3eM,5387
-ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=taWLpF5IVglxlsF9HM2dIoKDXuQREaCRAXtJeG5gKzs,2073
-ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=7bv9XqRkm1pjxiVL4Cm1cArExnolId8hQKFHtvlkCI8,10061
+ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
+ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
+ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=Ui6BrehF3zJJN7uTxKwbO2yCY9mYjbewlQzAxzZv9Es,10274
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
+ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrssTfaV1sMyUvdMbsg,5573
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
@@ -172,9 +178,9 @@ ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUG
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
-ai_edge_torch/odml_torch/_torch_future.py,sha256=6fMr6C6rP3roVRFQxBH9Yr7656WtodZuNvhxLzKad_g,3320
+ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=xaJP_dca7P1sC0KvAI5ExxklIoW9nCrRM3vsdvhVfm8,13451
+ai_edge_torch/odml_torch/export.py,sha256=YN7QPrQ8W6T3YVOdyIGadfSQuBroMjIqAMB9FeUa7Ho,13447
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -186,16 +192,16 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GWFl7WWgExLXu6FEYxnig5_g6hd_Sfnl8690uFg2-CU,1013
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
+ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
 ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
-ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=0dIkVNN9ubgfwSxeaBFfTmWarWWG8Q1M0HX_1FShHik,2610
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
@@ -206,8 +212,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/METADATA,sha256=NWhBDavlpQ2NnaCILfBYuSK6xCLz9YOAFzDH7Ydrlgg,1966
-ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/METADATA,sha256=1IZCBOcKVCWbEfAQvEMgt39cuATDIzpK6AhW_gTnIY4,1966
+ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/RECORD,,
ai_edge_torch/odml_torch/lowerings/decomp.py DELETED
@@ -1,69 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Torch export decompositions to run before lowering."""
-
-import functools
-
-import torch
-
-
-@functools.cache
-def decompositions():
-  # Base: Core ATen decompositions
-  decompositions = torch._decomp.core_aten_decompositions()
-
-  decompositions.update(
-      torch._decomp.get_decompositions([
-          torch.ops.aten.upsample_nearest2d,
-          torch.ops.aten._native_batch_norm_legit.no_stats,
-          torch.ops.aten._native_batch_norm_legit_functional,
-          torch.ops.aten._adaptive_avg_pool2d,
-          torch.ops.aten._adaptive_avg_pool3d,
-          torch.ops.aten.grid_sampler_2d,
-          torch.ops.aten.native_group_norm,
-          torch.ops.aten.native_dropout,
-          torch.ops.aten.reflection_pad1d,
-          torch.ops.aten.reflection_pad2d,
-          torch.ops.aten.reflection_pad3d,
-          torch.ops.aten.replication_pad1d,
-          torch.ops.aten.replication_pad2d,
-          torch.ops.aten.replication_pad3d,
-          torch.ops.aten.addmm,
-      ])
-  )
-
-  torch._decomp.remove_decompositions(
-      decompositions,
-      [
-          torch.ops.aten.roll,
-          # Torch's default einsum impl/decompositions is less efficient and
-          # optimized through converter than JAX's impl. Disable einsum
-          # decomposition to use JAX bridge for a more efficient lowering.
-          torch.ops.aten.einsum.default,
-      ],
-  )
-
-  # Override noop aten op decompositions for faster run_decompositions.
-  decompositions[torch.ops.aten.alias.default] = lambda x: x
-  decompositions[torch.ops.aten.detach.default] = lambda x: x
-
-  # Override _safe_softmax decompositions with regular softmax.
-  # _safe_softmax introduces additional check-select ops to guard extreme
-  # input values to softmax, which could make the converted model inefficient
-  # on-device.
-  if hasattr(torch.ops.aten, "_safe_softmax"):
-    decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
-
-  return decompositions