ai-edge-torch-nightly 0.3.0.dev20250121__py3-none-any.whl → 0.3.0.dev20250123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. ai_edge_torch/__init__.py +1 -1
  2. ai_edge_torch/_convert/conversion.py +6 -10
  3. ai_edge_torch/_convert/fx_passes/__init__.py +1 -1
  4. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +6 -3
  5. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +9 -11
  6. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -3
  7. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -0
  8. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py +33 -0
  9. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -0
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -3
  11. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +4 -4
  12. ai_edge_torch/_convert/test/test_convert.py +2 -2
  13. ai_edge_torch/fx_infra/__init__.py +32 -0
  14. ai_edge_torch/fx_infra/_canonicalize_pass.py +27 -0
  15. ai_edge_torch/fx_infra/_safe_run_decompositions.py +57 -0
  16. ai_edge_torch/fx_infra/decomp.py +80 -0
  17. ai_edge_torch/fx_infra/graph_utils.py +42 -0
  18. ai_edge_torch/{fx_pass_base.py → fx_infra/pass_base.py} +0 -28
  19. ai_edge_torch/generative/fx_passes/__init__.py +3 -3
  20. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -3
  21. ai_edge_torch/hlfb/mark_pattern/__init__.py +0 -1
  22. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +5 -21
  23. ai_edge_torch/hlfb/mark_pattern/pattern.py +20 -11
  24. ai_edge_torch/hlfb/test/test_mark_pattern.py +18 -15
  25. ai_edge_torch/odml_torch/_torch_future.py +0 -27
  26. ai_edge_torch/odml_torch/export.py +6 -8
  27. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  28. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +65 -0
  29. ai_edge_torch/version.py +1 -1
  30. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/METADATA +1 -1
  31. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/RECORD +34 -28
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +0 -69
  33. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/LICENSE +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/WHEEL +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
15
  from ai_edge_torch._config import config
17
16
  from ai_edge_torch._convert.converter import convert
18
17
  from ai_edge_torch._convert.converter import signature
@@ -20,6 +19,7 @@ from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
20
19
  from ai_edge_torch.model import Model
21
20
  from ai_edge_torch.version import __version__
22
21
 
22
+
23
23
  def load(path: str) -> Model:
24
24
  """Imports an ai_edge_torch model from disk.
25
25
 
@@ -17,7 +17,7 @@ import logging
17
17
  from typing import Any, Literal, Optional, Union
18
18
 
19
19
  import ai_edge_torch
20
- from ai_edge_torch import fx_pass_base
20
+ from ai_edge_torch import fx_infra
21
21
  from ai_edge_torch import lowertools
22
22
  from ai_edge_torch import model
23
23
  from ai_edge_torch._convert import fx_passes
@@ -53,7 +53,7 @@ def _run_convert_passes(
53
53
  fx_passes.CanonicalizePass(),
54
54
  ]
55
55
 
56
- exported_program = fx_pass_base.run_passes(exported_program, passes)
56
+ exported_program = fx_infra.run_passes(exported_program, passes)
57
57
  return exported_program
58
58
 
59
59
 
@@ -125,14 +125,10 @@ def convert_signatures(
125
125
  else:
126
126
  exported_program = torch.export.export(**kwargs, strict=True)
127
127
 
128
- if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
129
- # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
130
- # stop-gap table which replicates the old behaviour of post-dispatch IR.
131
- # This could help ensure the collection of aten ops remaining still as the
132
- # implementation of torch.export changes.
133
- exported_program = exported_program.run_decompositions(
134
- torch._decomp._decomp_table_to_post_autograd_aten()
135
- )
128
+ exported_program = fx_infra.safe_run_decompositions(
129
+ exported_program,
130
+ fx_infra.decomp.pre_convert_decomp(),
131
+ )
136
132
  return exported_program
137
133
 
138
134
  exported_programs: torch.export.ExportedProgram = [
@@ -20,4 +20,4 @@ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import Bu
20
20
  from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
21
21
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
22
22
  from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
23
- from ai_edge_torch.fx_pass_base import CanonicalizePass
23
+ from ai_edge_torch.fx_infra import CanonicalizePass
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from typing import Any, Callable
17
- from ai_edge_torch import fx_pass_base
17
+ from ai_edge_torch import fx_infra
18
18
  from ai_edge_torch import lowertools
19
19
  import torch
20
20
  import torch.utils._pytree as pytree
@@ -25,6 +25,9 @@ _composite_builders: dict[
25
25
 
26
26
 
27
27
  def _register_composite_builder(op):
28
+ # Remove op from pre_convert_decomp to keep this in the decomposed graph.
29
+ fx_infra.decomp.remove_pre_convert_decomp(op)
30
+
28
31
  def inner(func):
29
32
  if isinstance(op, torch._ops.OpOverloadPacket):
30
33
  for overload in op.overloads():
@@ -276,7 +279,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
276
279
  node.target = embedding
277
280
 
278
281
 
279
- class BuildAtenCompositePass(fx_pass_base.PassBase):
282
+ class BuildAtenCompositePass(fx_infra.PassBase):
280
283
 
281
284
  def call(self, graph_module: torch.fx.GraphModule):
282
285
  for node in graph_module.graph.nodes:
@@ -285,4 +288,4 @@ class BuildAtenCompositePass(fx_pass_base.PassBase):
285
288
 
286
289
  graph_module.graph.lint()
287
290
  graph_module.recompile()
288
- return fx_pass_base.PassResult(graph_module, True)
291
+ return fx_infra.PassResult(graph_module, True)
@@ -16,7 +16,7 @@
16
16
 
17
17
  import functools
18
18
 
19
- from ai_edge_torch import fx_pass_base
19
+ from ai_edge_torch import fx_infra
20
20
  from ai_edge_torch.hlfb import mark_pattern
21
21
  from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
22
22
  import torch
@@ -41,7 +41,7 @@ def _get_upsample_bilinear2d_pattern():
41
41
  x, scale_factor=2, mode="bilinear", align_corners=False
42
42
  ),
43
43
  export_args=(torch.rand(1, 3, 100, 100),),
44
- decomp_table=_INTERPOLATE_DECOMPOSITIONS,
44
+ extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
45
45
  )
46
46
 
47
47
  @pattern.register_attr_builder
@@ -65,7 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
65
65
  x, scale_factor=2, mode="bilinear", align_corners=True
66
66
  ),
67
67
  export_args=(torch.rand(1, 3, 100, 100),),
68
- decomp_table=_INTERPOLATE_DECOMPOSITIONS,
68
+ extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
69
69
  )
70
70
 
71
71
  @pattern.register_attr_builder
@@ -89,7 +89,7 @@ def _get_interpolate_nearest2d_pattern():
89
89
  x, scale_factor=2, mode="nearest"
90
90
  ),
91
91
  export_args=(torch.rand(1, 3, 100, 100),),
92
- decomp_table=_INTERPOLATE_DECOMPOSITIONS,
92
+ extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
93
93
  )
94
94
 
95
95
  @pattern.register_attr_builder
@@ -104,7 +104,7 @@ def _get_interpolate_nearest2d_pattern():
104
104
  return pattern
105
105
 
106
106
 
107
- class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
107
+ class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
108
108
 
109
109
  def __init__(self):
110
110
  super().__init__()
@@ -115,11 +115,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
115
115
  ]
116
116
 
117
117
  def call(self, exported_program: torch.export.ExportedProgram):
118
- exported_program = fx_pass_base.run_passes(
119
- exported_program, [fx_pass_base.CanonicalizePass()]
120
- )
121
- exported_program = exported_program.run_decompositions(
122
- _INTERPOLATE_DECOMPOSITIONS
118
+ exported_program = fx_infra.safe_run_decompositions(
119
+ exported_program,
120
+ _INTERPOLATE_DECOMPOSITIONS,
123
121
  )
124
122
 
125
123
  graph_module = exported_program.graph_module
@@ -128,4 +126,4 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
128
126
 
129
127
  graph_module.graph.lint()
130
128
  graph_module.recompile()
131
- return fx_pass_base.ExportedProgramPassResult(exported_program, True)
129
+ return fx_infra.ExportedProgramPassResult(exported_program, True)
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from ai_edge_torch import fx_pass_base
16
+ from ai_edge_torch import fx_infra
17
17
  from ai_edge_torch import lowertools
18
18
  import torch
19
19
  import torch.utils._pytree as pytree
@@ -61,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
61
61
  node.target = debuginfo_writer
62
62
 
63
63
 
64
- class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
64
+ class InjectMlirDebuginfoPass(fx_infra.PassBase):
65
65
  """DEPRECATED: Debuginfo is injected automatically by odml_torch."""
66
66
 
67
67
  def call(self, graph_module: torch.fx.GraphModule):
@@ -70,4 +70,4 @@ class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
70
70
 
71
71
  graph_module.graph.lint()
72
72
  graph_module.recompile()
73
- return fx_pass_base.PassResult(graph_module, True)
73
+ return fx_infra.PassResult(graph_module, True)
@@ -13,4 +13,5 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import _decomp_registry
16
17
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
@@ -0,0 +1,33 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Remove decompositions for ops to keep in layout optimization."""
16
+ from ai_edge_torch import fx_infra
17
+ import torch
18
+
19
+ __all__ = []
20
+
21
+ aten = torch.ops.aten
22
+
23
+ _OPS_TO_KEEP = [
24
+ aten.conv2d,
25
+ aten.max_pool2d,
26
+ aten._softmax.default,
27
+ aten.group_norm.default,
28
+ aten.native_group_norm.default,
29
+ aten.reflection_pad2d.default,
30
+ ]
31
+
32
+ for op in _OPS_TO_KEEP:
33
+ fx_infra.decomp.remove_pre_convert_decomp(op)
@@ -20,6 +20,7 @@ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import uti
20
20
  class OpFuncRegistry(dict):
21
21
 
22
22
  def register(self, op):
23
+
23
24
  ops = utils.flatten_torch_op_overloads(op)
24
25
 
25
26
  def inner(func):
@@ -18,7 +18,7 @@ import operator
18
18
  import os
19
19
  from typing import Union
20
20
 
21
- from ai_edge_torch import fx_pass_base
21
+ from ai_edge_torch import fx_infra
22
22
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
23
23
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
24
24
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
@@ -30,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
30
30
  TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
31
31
 
32
32
 
33
- class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
33
+ class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
34
34
 
35
35
  def get_source_meta(self, node: torch.fx.Node):
36
36
  keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -300,4 +300,4 @@ class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
300
300
  # Mark const node again for debugging
301
301
  self.mark_const_nodes(exported_program)
302
302
 
303
- return fx_pass_base.ExportedProgramPassResult(exported_program, True)
303
+ return fx_infra.ExportedProgramPassResult(exported_program, True)
@@ -15,11 +15,11 @@
15
15
  """Pass to remove all non user outputs from exported program."""
16
16
 
17
17
 
18
- from ai_edge_torch import fx_pass_base
18
+ from ai_edge_torch import fx_infra
19
19
  import torch
20
20
 
21
21
 
22
- class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
22
+ class RemoveNonUserOutputsPass(fx_infra.ExportedProgramPassBase):
23
23
  """This pass removes all non user outputs from the exported program's output.
24
24
 
25
25
  The FX graph may output more tensors/data than what user's original model
@@ -47,6 +47,6 @@ class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
47
47
  node.args = (tuple(new_outputs),)
48
48
  exported_program.graph_signature.output_specs = new_output_specs
49
49
 
50
- exported_program.graph_module.graph.lint()
50
+ exported_program.graph.eliminate_dead_code()
51
51
  exported_program.graph_module.recompile()
52
- return fx_pass_base.ExportedProgramPassResult(exported_program, True)
52
+ return fx_infra.ExportedProgramPassResult(exported_program, True)
@@ -511,7 +511,7 @@ class TestConvert(googletest.TestCase):
511
511
  # Step 1: export resnet18
512
512
  args = (torch.randn(1, 3, 224, 224),)
513
513
  m = torchvision.models.resnet18().eval()
514
- m = torch._export.capture_pre_autograd_graph(m, args)
514
+ m = torch.export.export_for_training(m, args).module()
515
515
 
516
516
  # Step 2: Insert observers or fake quantize modules
517
517
  quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
@@ -533,7 +533,7 @@ class TestConvert(googletest.TestCase):
533
533
  # Step 1: export resnet18
534
534
  args = (torch.randn(1, 3, 224, 224),)
535
535
  m = torchvision.models.resnet18().eval()
536
- m = torch._export.capture_pre_autograd_graph(m, args)
536
+ m = torch.export.export_for_training(m, args).module()
537
537
 
538
538
  # Step 2: Insert observers or fake quantize modules
539
539
  quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
@@ -0,0 +1,32 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from ai_edge_torch.fx_infra import _canonicalize_pass
17
+ from ai_edge_torch.fx_infra import _safe_run_decompositions
18
+ from ai_edge_torch.fx_infra import decomp
19
+ from ai_edge_torch.fx_infra import graph_utils
20
+ from ai_edge_torch.fx_infra import pass_base
21
+
22
+
23
+ PassBase = pass_base.PassBase
24
+ PassResult = pass_base.PassResult
25
+ FxPassBase = pass_base.FxPassBase
26
+ FxPassResult = pass_base.FxPassResult
27
+ ExportedProgramPassBase = pass_base.ExportedProgramPassBase
28
+ ExportedProgramPassResult = pass_base.ExportedProgramPassResult
29
+ run_passes = pass_base.run_passes
30
+
31
+ CanonicalizePass = _canonicalize_pass.CanonicalizePass
32
+ safe_run_decompositions = _safe_run_decompositions.safe_run_decompositions
@@ -0,0 +1,27 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from ai_edge_torch.fx_infra import _safe_run_decompositions
16
+ from ai_edge_torch.fx_infra import pass_base
17
+ import torch
18
+
19
+
20
+ class CanonicalizePass(pass_base.ExportedProgramPassBase):
21
+
22
+ def call(self, exported_program: torch.export.ExportedProgram):
23
+ exported_program = _safe_run_decompositions.safe_run_decompositions(
24
+ exported_program, {}
25
+ )
26
+
27
+ return pass_base.ExportedProgramPassResult(exported_program, True)
@@ -0,0 +1,57 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
16
+ import torch
17
+
18
+
19
+ # A dummy decomp table for running ExportedProgram.run_decompositions without
20
+ # any op decompositions but just aot_export_module. Due to the check in
21
+ # run_decompositions, if None or an empty dict is passed as decomp_table,
22
+ # it will run the default aten-coreaten decompositions. Therefore a non-empty
23
+ # dummy decomp table is needed.
24
+ # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
25
+ _DUMMY_DECOMP_TABLE = {
26
+ torch._ops.OperatorBase(): lambda: None,
27
+ }
28
+
29
+
30
+ def safe_run_decompositions(exported_program, decomp_table=None):
31
+ """Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
32
+
33
+ if decomp_table is not None and not decomp_table:
34
+ # Empty decomp table means no op decompositions. Use dummy decomp table
35
+ # instead for backward compatibility.
36
+ decomp_table = _DUMMY_DECOMP_TABLE
37
+
38
+ for node in exported_program.graph.nodes:
39
+ if node.target == torch.ops.aten.view.default:
40
+ # Passes or torch.export may generate aten.view nodes not respecting the
41
+ # tensor memory format. Changes all the aten.view to torch.reshape
42
+ # for retracing. If the input memory format is already contiguous,
43
+ # retracing in run_decomposition below would decompose torch.reshape
44
+ # back to one aten.view.
45
+ node.target = lambda self, size: torch.reshape(self.contiguous(), size)
46
+
47
+ exported_program = exported_program.run_decompositions(decomp_table)
48
+
49
+ if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
50
+ for node in exported_program.graph.nodes:
51
+ if node.target == torch.ops.aten._assert_tensor_metadata.default:
52
+ exported_program.graph.erase_node(node)
53
+
54
+ exported_program.graph.eliminate_dead_code()
55
+ exported_program.graph_module.recompile()
56
+
57
+ return exported_program
@@ -0,0 +1,80 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Decomposition tables applied before conversion and before lowering, with helpers to edit them."""
16
+ import torch
17
+
18
+ # Decompositions to be run after torch.export before conversion and any
19
+ # passes. Remove ops from this decomposition table if they need to be preserved
20
+ # in passes.
21
+ _pre_convert_decomp = torch._decomp.core_aten_decompositions().copy()
22
+
23
+
24
+ # Decompositions to be run after conversion before odml_torch passes and
25
+ # lowerings.
26
+ _pre_lower_decomp = torch._decomp.core_aten_decompositions().copy()
27
+
28
+
29
+ def _get_ops(op):
30
+ if isinstance(op, torch._ops.OpOverloadPacket):
31
+ return [getattr(op, overload) for overload in op.overloads()]
32
+ else:
33
+ return [op]
34
+
35
+
36
+ def pre_convert_decomp():
37
+ """Decompositions to be run after torch.export before conversion and any passes."""
38
+ return _pre_convert_decomp.copy()
39
+
40
+
41
+ def pre_lower_decomp():
42
+ """Decompositions to be run after conversion before odml_torch passes and lowerings."""
43
+ return _pre_lower_decomp.copy()
44
+
45
+
46
+ def remove_pre_lower_decomp(op):
47
+ # Also remove from pre_convert_decomp which always runs before pre_lower_
48
+ # decomp.
49
+ remove_pre_convert_decomp(op)
50
+
51
+ for op_ in _get_ops(op):
52
+ _pre_lower_decomp.pop(op_, None)
53
+
54
+
55
+ def remove_pre_convert_decomp(op):
56
+ for op_ in _get_ops(op):
57
+ _pre_convert_decomp.pop(op_, None)
58
+
59
+
60
+ def add_pre_convert_decomp(op, decomp):
61
+ # Also add decomp to pre_lower_decomp which runs after pre_convert_decomp.
62
+ add_pre_lower_decomp(op, decomp)
63
+
64
+ for op_ in _get_ops(op):
65
+ _pre_convert_decomp[op_] = decomp
66
+
67
+
68
+ def add_pre_lower_decomp(op, decomp):
69
+ for op_ in _get_ops(op):
70
+ _pre_lower_decomp[op_] = decomp
71
+
72
+
73
+ def update_pre_convert_decomp(decomps):
74
+ for op, decomp in decomps.items():
75
+ add_pre_convert_decomp(op, decomp)
76
+
77
+
78
+ def update_pre_lower_decomp(decomps):
79
+ for op, decomp in decomps.items():
80
+ add_pre_lower_decomp(op, decomp)
@@ -0,0 +1,42 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """FX graph utilities."""
16
+ import torch
17
+
18
+
19
+ def remove_dangling_args(graph_module: torch.fx.GraphModule):
20
+ """Removes dangling args from the graph."""
21
+ for node in graph_module.graph.nodes:
22
+ if node.op == "placeholder" and not node.users:
23
+ graph_module.graph.erase_node(node)
24
+
25
+ graph_module.graph.lint()
26
+ graph_module.recompile()
27
+ return graph_module
28
+
29
+
30
+ def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
31
+ """Removes aten._assert_tensor_metadata nodes from the graph.
32
+
33
+ This op is inserted by torch.export to check tensor metadata on custom ops. It
34
+ can break pattern matching and lowering.
35
+ """
36
+ for node in graph_module.graph.nodes:
37
+ if node.target == torch.ops.aten._assert_tensor_metadata.default:
38
+ graph_module.graph.erase_node(node)
39
+
40
+ graph_module.graph.lint()
41
+ graph_module.recompile()
42
+ return graph_module
@@ -80,31 +80,3 @@ def run_passes(
80
80
  constants=exported_program.constants,
81
81
  )
82
82
  return exported_program
83
-
84
-
85
- class CanonicalizePass(ExportedProgramPassBase):
86
-
87
- # A dummy decomp table for running ExportedProgram.run_decompositions without
88
- # any op decompositions but just aot_export_module. Due to the check in
89
- # run_decompositions, if None or an empty dict is passed as decomp_table,
90
- # it will run the default aten-coreaten decompositions. Therefore a non-empty
91
- # dummy decomp table is needed.
92
- # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
93
- _DUMMY_DECOMP_TABLE = {
94
- torch._ops.OperatorBase(): lambda: None,
95
- }
96
-
97
- def call(self, exported_program: torch.export.ExportedProgram):
98
- for node in exported_program.graph.nodes:
99
- if node.target == torch.ops.aten.view.default:
100
- # Passes or torch.export may generate aten.view nodes not respecting the
101
- # tensor memory format. Changes all the aten.view to torch.reshape
102
- # for retracing. If the input memory format is already contiguous,
103
- # retracing in run_decomposition below would decompose torch.reshape
104
- # back to one aten.view.
105
- node.target = lambda self, size: torch.reshape(self.contiguous(), size)
106
-
107
- exported_program = exported_program.run_decompositions(
108
- self._DUMMY_DECOMP_TABLE
109
- )
110
- return ExportedProgramPassResult(exported_program, True)
@@ -12,8 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from ai_edge_torch import fx_pass_base
16
- from ai_edge_torch.fx_pass_base import CanonicalizePass
15
+ from ai_edge_torch import fx_infra
16
+ from ai_edge_torch.fx_infra import CanonicalizePass
17
17
  from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
18
18
  import torch
19
19
 
@@ -21,7 +21,7 @@ import torch
21
21
  def run_generative_passes(
22
22
  exported_program: torch.export.ExportedProgram,
23
23
  ) -> torch.export.ExportedProgram:
24
- return fx_pass_base.run_passes(
24
+ return fx_infra.run_passes(
25
25
  exported_program,
26
26
  [
27
27
  RemoveSDPACompositeZeroMaskPass(),
@@ -12,12 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from ai_edge_torch import fx_pass_base
15
+ from ai_edge_torch import fx_infra
16
16
  from ai_edge_torch import lowertools
17
17
  import torch
18
18
 
19
19
 
20
- class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
20
+ class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
21
21
 
22
22
  def is_zero_tensor_node(self, node: torch.fx.Node):
23
23
  return node.target == torch.ops.aten.zeros.default
@@ -47,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
47
47
 
48
48
  exported_program.graph_module.graph.lint()
49
49
  exported_program.graph_module.recompile()
50
- return fx_pass_base.ExportedProgramPassResult(exported_program, True)
50
+ return fx_infra.ExportedProgramPassResult(exported_program, True)
@@ -90,7 +90,6 @@ def mark_pattern(
90
90
  graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
91
91
 
92
92
  match_with_attrs = pattern.match(graph_module_to_match)
93
-
94
93
  for match, attr in match_with_attrs:
95
94
  match_id = _get_uuid()
96
95
 
@@ -14,8 +14,13 @@
14
14
  # ==============================================================================
15
15
  """FX graph utilities for pattern matching clean ups."""
16
16
 
17
+ from ai_edge_torch import fx_infra
17
18
  import torch
18
19
 
20
+ remove_dangling_args = fx_infra.graph_utils.remove_dangling_args
21
+ remove_assert_tensor_metadata_nodes = (
22
+ fx_infra.graph_utils.remove_assert_tensor_metadata_nodes
23
+ )
19
24
 
20
25
  def is_clone_op(node: torch.fx.Node) -> bool:
21
26
  """Checks if the node is a clone op."""
@@ -46,24 +51,3 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
46
51
  gm.graph.lint()
47
52
  gm.recompile()
48
53
  return gm
49
-
50
-
51
- def remove_dangling_args(gm: torch.fx.GraphModule):
52
- """Removes dangling args from the graph.
53
-
54
- Args:
55
- gm: The graph module to remove dangling args from.
56
-
57
- Returns:
58
- The graph module with dangling args removed.
59
- """
60
- nodes_to_erase = []
61
- for node in gm.graph.nodes:
62
- if node.op == "placeholder" and len(node.users) == 0:
63
- nodes_to_erase.append(node)
64
- for node in nodes_to_erase:
65
- gm.graph.erase_node(node)
66
-
67
- gm.graph.lint()
68
- gm.recompile()
69
- return gm
@@ -17,7 +17,7 @@
17
17
  import dataclasses
18
18
  from typing import Any, Callable, Optional, Union
19
19
 
20
- from ai_edge_torch import fx_pass_base
20
+ from ai_edge_torch import fx_infra
21
21
  from ai_edge_torch.hlfb.mark_pattern import fx_utils
22
22
  import torch
23
23
 
@@ -118,7 +118,7 @@ def _find_scalar_attr(
118
118
  track_args[tracker.pattern_arg_pos] = source
119
119
  ep = torch.export.export(pattern_module, tuple(track_args))
120
120
  if decomp_table is not None:
121
- ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
121
+ ep = fx_infra.run_passes(ep, [fx_infra.CanonicalizePass()])
122
122
  ep = ep.run_decompositions(decomp_table)
123
123
 
124
124
  scalar_locs = set()
@@ -152,13 +152,15 @@ class Pattern:
152
152
  self,
153
153
  name: str,
154
154
  module: Union[Callable, torch.nn.Module],
155
- export_args: tuple[Any],
155
+ export_args: tuple[Any, ...],
156
156
  *,
157
157
  attr_builder: Callable[
158
158
  ["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
159
159
  ] = None,
160
160
  scalar_attr_trackers: list[ScalarAttrTracker] = None,
161
- decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
161
+ extra_decomp_table: Optional[
162
+ dict[torch._ops.OperatorBase, Callable]
163
+ ] = None,
162
164
  ):
163
165
  """The PyTorch computation pattern to match against a model.
164
166
 
@@ -177,8 +179,9 @@ class Pattern:
177
179
  scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
178
180
  args in `export_args`, which are used to track the attr occurrence(s)
179
181
  and retrieve their values from the matched subgraph.
180
- decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]): The
181
- decomposition table to be run on the pattern's exported program.
182
+ extra_decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
183
+ Extra decomposition to be run on the pattern's exported program (in
184
+ addition to the default pre_convert_decomp in fx_infra.decomp).
182
185
  """
183
186
  if not isinstance(module, torch.nn.Module):
184
187
 
@@ -200,11 +203,14 @@ class Pattern:
200
203
  )
201
204
 
202
205
  exported_program = torch.export.export(module, export_args)
203
- if decomp_table is not None:
204
- exported_program = fx_pass_base.run_passes(
205
- exported_program, [fx_pass_base.CanonicalizePass()]
206
- )
207
- exported_program = exported_program.run_decompositions(decomp_table)
206
+
207
+ decomp_table = fx_infra.decomp.pre_convert_decomp()
208
+ if extra_decomp_table is not None:
209
+ decomp_table.update(extra_decomp_table)
210
+
211
+ exported_program = fx_infra.safe_run_decompositions(
212
+ exported_program, decomp_table
213
+ )
208
214
 
209
215
  self.exported_program = exported_program
210
216
  self.graph_module = self.exported_program.graph_module
@@ -222,6 +228,9 @@ class Pattern:
222
228
  # sanitization.
223
229
  self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
224
230
  self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
231
+ self.graph_module = fx_utils.remove_assert_tensor_metadata_nodes(
232
+ self.graph_module
233
+ )
225
234
 
226
235
  # Builds list of ordered input and output nodes.
227
236
  self.graph_nodes_map = {}
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Tests for mark_pattern."""
16
16
 
17
+ from ai_edge_torch import fx_infra
17
18
  from ai_edge_torch import lowertools
18
19
  from ai_edge_torch.hlfb import mark_pattern
19
20
  from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
@@ -22,11 +23,13 @@ import torch
22
23
  from absl.testing import absltest as googletest
23
24
 
24
25
 
25
- def _export_stablehlo_mlir(model, args=None):
26
- if not isinstance(model, torch.export.ExportedProgram):
27
- ep = torch.export.export(model, args)
28
- else:
29
- ep = model
26
+ def _export_and_decomp(mod, args):
27
+ ep = torch.export.export(mod, args)
28
+ ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_lower_decomp())
29
+ return ep
30
+
31
+
32
+ def _to_mlir(ep: torch.export.ExportedProgram):
30
33
  return lowertools.exported_program_to_mlir_text(ep)
31
34
 
32
35
 
@@ -47,9 +50,9 @@ class TestMarkPattern(googletest.TestCase):
47
50
 
48
51
  model = TestModel().eval()
49
52
  args = (torch.rand(20, 20),)
50
- exported_program = torch.export.export(model, args)
53
+ exported_program = _export_and_decomp(model, args)
51
54
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
52
- mlir = _export_stablehlo_mlir(exported_program)
55
+ mlir = _to_mlir(exported_program)
53
56
 
54
57
  lowertools.assert_string_count(
55
58
  self,
@@ -73,9 +76,9 @@ class TestMarkPattern(googletest.TestCase):
73
76
 
74
77
  model = TestModel().eval()
75
78
  args = (torch.rand(20, 20),)
76
- exported_program = torch.export.export(model, args)
79
+ exported_program = _export_and_decomp(model, args)
77
80
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
78
- mlir = _export_stablehlo_mlir(exported_program)
81
+ mlir = _to_mlir(exported_program)
79
82
 
80
83
  lowertools.assert_string_count(
81
84
  self,
@@ -99,9 +102,9 @@ class TestMarkPattern(googletest.TestCase):
99
102
 
100
103
  model = TestModel().eval()
101
104
  args = (torch.rand(20, 20),)
102
- exported_program = torch.export.export(model, args)
105
+ exported_program = _export_and_decomp(model, args)
103
106
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
104
- mlir = _export_stablehlo_mlir(exported_program)
107
+ mlir = _to_mlir(exported_program)
105
108
 
106
109
  lowertools.assert_string_count(
107
110
  self,
@@ -137,9 +140,9 @@ class TestMarkPattern(googletest.TestCase):
137
140
 
138
141
  model = TestModel().eval()
139
142
  args = (torch.rand(10, 10),)
140
- exported_program = torch.export.export(model, args)
143
+ exported_program = _export_and_decomp(model, args)
141
144
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
142
- mlir = _export_stablehlo_mlir(exported_program)
145
+ mlir = _to_mlir(exported_program)
143
146
 
144
147
  lowertools.assert_string_count(
145
148
  self,
@@ -169,9 +172,9 @@ class TestMarkPattern(googletest.TestCase):
169
172
 
170
173
  model = TestModel().eval()
171
174
  args = (torch.rand(20, 20), torch.rand(20, 20))
172
- exported_program = torch.export.export(model, args)
175
+ exported_program = _export_and_decomp(model, args)
173
176
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
174
- mlir = _export_stablehlo_mlir(exported_program)
177
+ mlir = _to_mlir(exported_program)
175
178
 
176
179
  lowertools.assert_string_count(
177
180
  self,
@@ -59,30 +59,3 @@ def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
59
59
  ordered_tensor_constants = tuple()
60
60
 
61
61
  return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
62
-
63
-
64
- # TODO(b/331481564): Replace this with CanonicalizePass + run_decomposition
65
- def safe_run_decompositions(exported_program, decomp_table=None):
66
- for node in exported_program.graph.nodes:
67
- if node.target == torch.ops.aten.view.default:
68
- # Passes or torch.export may generate aten.view nodes not respecting the
69
- # tensor memory format. Changes all the aten.view to torch.reshape
70
- # for retracing. If the input memory format is already contiguous,
71
- # retracing in run_decomposition below would decompose torch.reshape
72
- # back to one aten.view.
73
- node.target = lambda self, size: torch.reshape(self.contiguous(), size)
74
-
75
- return exported_program.run_decompositions(decomp_table)
76
-
77
-
78
- def dummy_decomp_table():
79
- """Build dummy decomp table for run_decompositions without any decompositions.
80
-
81
- Compatible for torch<=2.5.
82
-
83
- Returns:
84
- Decomp table for ExportedProgram.run_decompositions.
85
- """
86
- return {
87
- torch._ops.OperatorBase(): lambda: None,
88
- }
@@ -20,6 +20,7 @@ import io
20
20
  import operator
21
21
  from typing import Any, Callable, Optional
22
22
 
23
+ from ai_edge_torch import fx_infra
23
24
  from jax.lib import xla_extension
24
25
  from jax._src.lib.mlir import ir
25
26
  from jax._src.lib.mlir.dialects import func
@@ -302,16 +303,13 @@ def exported_program_to_mlir(
302
303
  exported_program: torch.export.ExportedProgram,
303
304
  ) -> MlirLowered:
304
305
  """Lower the exported program to MLIR."""
305
- exported_program = _torch_future.safe_run_decompositions(
306
- exported_program, lowerings.decompositions()
306
+ exported_program = fx_infra.safe_run_decompositions(
307
+ exported_program,
308
+ fx_infra.decomp.pre_lower_decomp(),
307
309
  )
308
-
309
310
  _convert_i64_to_i32(exported_program)
310
-
311
- # No decompositions but just retracing/canonicalization.
312
- exported_program = _torch_future.safe_run_decompositions(
313
- exported_program, _torch_future.dummy_decomp_table()
314
- )
311
+ # Run decompositions for retracing and canonicalization.
312
+ exported_program = fx_infra.safe_run_decompositions(exported_program, {})
315
313
 
316
314
  # Passes below mutate the exported program to a state not executable by torch.
317
315
  # Do not call run_decompositions after applying the passes.
@@ -15,6 +15,7 @@
15
15
  from . import _basic
16
16
  from . import _batch_norm
17
17
  from . import _convolution
18
+ from . import _decomp_registry
18
19
  from . import _jax_lowerings
19
20
  from . import _layer_norm
20
21
  from . import _quantized_decomposed
@@ -22,6 +23,5 @@ from . import _rand
22
23
  from . import context
23
24
  from . import registry
24
25
  from . import utils
25
- from .decomp import decompositions
26
26
  from .registry import lookup
27
27
  from .registry import lower
@@ -0,0 +1,65 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch export decompositions to run before lowering."""
16
+
17
+ from ai_edge_torch import fx_infra
18
+ import torch
19
+
20
+
21
+ fx_infra.decomp.update_pre_lower_decomp(
22
+ torch._decomp.get_decompositions([
23
+ torch.ops.aten.upsample_nearest2d,
24
+ torch.ops.aten._native_batch_norm_legit.no_stats,
25
+ torch.ops.aten._native_batch_norm_legit_functional,
26
+ torch.ops.aten._adaptive_avg_pool2d,
27
+ torch.ops.aten._adaptive_avg_pool3d,
28
+ torch.ops.aten.grid_sampler_2d,
29
+ torch.ops.aten.native_group_norm,
30
+ torch.ops.aten.native_dropout,
31
+ torch.ops.aten.reflection_pad1d,
32
+ torch.ops.aten.reflection_pad2d,
33
+ torch.ops.aten.reflection_pad3d,
34
+ torch.ops.aten.replication_pad1d,
35
+ torch.ops.aten.replication_pad2d,
36
+ torch.ops.aten.replication_pad3d,
37
+ torch.ops.aten.addmm,
38
+ ])
39
+ )
40
+
41
+ fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
42
+
43
+ # Torch's default einsum impl/decompositions are less efficient and less
44
+ # well optimized by the converter than JAX's impl. Disable einsum
45
+ # decomposition to use the JAX bridge for a more efficient lowering.
46
+ fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.einsum.default)
47
+
48
+
49
+ # Override noop aten op decompositions for faster run_decompositions.
50
+ fx_infra.decomp.add_pre_convert_decomp(
51
+ torch.ops.aten.alias.default, lambda x: x
52
+ )
53
+ fx_infra.decomp.add_pre_convert_decomp(
54
+ torch.ops.aten.detach.default, lambda x: x
55
+ )
56
+
57
+ # Override _safe_softmax decompositions with regular softmax.
58
+ # _safe_softmax introduces additional check-select ops to guard extreme
59
+ # input values to softmax, which could make the converted model inefficient
60
+ # on-device.
61
+ if hasattr(torch.ops.aten, "_safe_softmax"):
62
+ fx_infra.decomp.add_pre_convert_decomp(
63
+ torch.ops.aten._safe_softmax.default,
64
+ torch.softmax,
65
+ )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250121"
16
+ __version__ = "0.3.0.dev20250123"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250121
3
+ Version: 0.3.0.dev20250123
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,32 +1,32 @@
1
- ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
1
+ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
2
2
  ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=Zlgc-dAVs0gcH1g86CJVS-yy-T9eJBuCXDFHRUscOVA,706
5
+ ai_edge_torch/version.py,sha256=szrxg2aB7mcm59IL_QVIqapmbw9Nz8AQ28vc9684bqY,706
7
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
- ai_edge_torch/_convert/conversion.py,sha256=pSDY0CzZQP_jAMjSfQ1O7Ud_AF5ZDeDF-nE3nAu_hoo,5815
7
+ ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
9
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
10
9
  ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
11
10
  ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
12
11
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
13
- ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
14
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
15
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
16
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
17
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
18
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
12
+ ai_edge_torch/_convert/fx_passes/__init__.py,sha256=dG4WIICk0FqCH9euvbYHHsybRN7B1cYcuxN_OYxmjWo,1263
13
+ ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
14
+ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
15
+ ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
16
+ ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
17
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
18
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
19
19
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
20
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
21
21
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
22
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
23
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
22
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
23
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=lgoH32l6zAbWTCpa_4-RWkHjqbNaPsBnhSObLIX8dL4,10551
24
24
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
25
25
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
26
26
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
27
27
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
28
28
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
29
- ai_edge_torch/_convert/test/test_convert.py,sha256=gK9QJuLbpjXt0l6tVnzl9Miq6GLkJR-hB67i3VE13Og,17224
29
+ ai_edge_torch/_convert/test/test_convert.py,sha256=o6tuJkD-ESaQxLxJpN104qpchm3LCtPmHinzQxe6PSg,17226
30
30
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
31
31
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
32
32
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -37,6 +37,12 @@ ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
37
37
  ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
38
38
  ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
39
39
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
40
+ ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpxXtuiI,1347
41
+ ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
42
+ ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=ZbWheeZ8ydsxCk2aVGUgUynrkEkBOMjBCzPhS5uq4sU,2595
43
+ ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
44
+ ai_edge_torch/fx_infra/graph_utils.py,sha256=3UZAOHWOUh2LCj1E2_AKQn3gRDILi9JCdqSScjyOd4M,1535
45
+ ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
40
46
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
41
47
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
48
  ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -115,8 +121,8 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6
115
121
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
116
122
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
117
123
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
118
- ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
119
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
124
+ ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
125
+ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=n5TbXdhBZi8jQe4j7-rox_MugMVvW8ReOhkTA3pfQkw,1919
120
126
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
121
127
  ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
122
128
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
@@ -159,11 +165,11 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yG
159
165
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
160
166
  ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
161
167
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
162
- ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=-BYE7MGMxr-VfBy8tAiiOaCqYv8ytJ0w5l2P8B7h3eM,5387
163
- ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=taWLpF5IVglxlsF9HM2dIoKDXuQREaCRAXtJeG5gKzs,2073
164
- ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=7bv9XqRkm1pjxiVL4Cm1cArExnolId8hQKFHtvlkCI8,10061
168
+ ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
169
+ ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
170
+ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=Ui6BrehF3zJJN7uTxKwbO2yCY9mYjbewlQzAxzZv9Es,10274
165
171
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
166
- ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
172
+ ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrssTfaV1sMyUvdMbsg,5573
167
173
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
168
174
  ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
169
175
  ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
@@ -172,9 +178,9 @@ ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUG
172
178
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
173
179
  ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
174
180
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
175
- ai_edge_torch/odml_torch/_torch_future.py,sha256=6fMr6C6rP3roVRFQxBH9Yr7656WtodZuNvhxLzKad_g,3320
181
+ ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
176
182
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
177
- ai_edge_torch/odml_torch/export.py,sha256=xaJP_dca7P1sC0KvAI5ExxklIoW9nCrRM3vsdvhVfm8,13451
183
+ ai_edge_torch/odml_torch/export.py,sha256=YN7QPrQ8W6T3YVOdyIGadfSQuBroMjIqAMB9FeUa7Ho,13447
178
184
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
179
185
  ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
180
186
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -186,16 +192,16 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
186
192
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
187
193
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
188
194
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
189
- ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GWFl7WWgExLXu6FEYxnig5_g6hd_Sfnl8690uFg2-CU,1013
195
+ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
190
196
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
191
197
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
192
198
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
199
+ ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
193
200
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
194
201
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
195
202
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
196
203
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
197
204
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
198
- ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=0dIkVNN9ubgfwSxeaBFfTmWarWWG8Q1M0HX_1FShHik,2610
199
205
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
200
206
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
201
207
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
@@ -206,8 +212,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
206
212
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
207
213
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
208
214
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
209
- ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
210
- ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/METADATA,sha256=NWhBDavlpQ2NnaCILfBYuSK6xCLz9YOAFzDH7Ydrlgg,1966
211
- ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
212
- ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
213
- ai_edge_torch_nightly-0.3.0.dev20250121.dist-info/RECORD,,
215
+ ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
216
+ ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/METADATA,sha256=1IZCBOcKVCWbEfAQvEMgt39cuATDIzpK6AhW_gTnIY4,1966
217
+ ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
218
+ ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
219
+ ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/RECORD,,
@@ -1,69 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Torch export decompositions to run before lowering."""
16
-
17
- import functools
18
-
19
- import torch
20
-
21
-
22
- @functools.cache
23
- def decompositions():
24
- # Base: Core ATen decompositions
25
- decompositions = torch._decomp.core_aten_decompositions()
26
-
27
- decompositions.update(
28
- torch._decomp.get_decompositions([
29
- torch.ops.aten.upsample_nearest2d,
30
- torch.ops.aten._native_batch_norm_legit.no_stats,
31
- torch.ops.aten._native_batch_norm_legit_functional,
32
- torch.ops.aten._adaptive_avg_pool2d,
33
- torch.ops.aten._adaptive_avg_pool3d,
34
- torch.ops.aten.grid_sampler_2d,
35
- torch.ops.aten.native_group_norm,
36
- torch.ops.aten.native_dropout,
37
- torch.ops.aten.reflection_pad1d,
38
- torch.ops.aten.reflection_pad2d,
39
- torch.ops.aten.reflection_pad3d,
40
- torch.ops.aten.replication_pad1d,
41
- torch.ops.aten.replication_pad2d,
42
- torch.ops.aten.replication_pad3d,
43
- torch.ops.aten.addmm,
44
- ])
45
- )
46
-
47
- torch._decomp.remove_decompositions(
48
- decompositions,
49
- [
50
- torch.ops.aten.roll,
51
- # Torch's default einsum impl/decompositions is less efficient and
52
- # optimized through converter than JAX's impl. Disable einsum
53
- # decomposition to use JAX bridge for a more efficient lowering.
54
- torch.ops.aten.einsum.default,
55
- ],
56
- )
57
-
58
- # Override noop aten op decompositions for faster run_decompositions.
59
- decompositions[torch.ops.aten.alias.default] = lambda x: x
60
- decompositions[torch.ops.aten.detach.default] = lambda x: x
61
-
62
- # Override _safe_softmax decompositions with regular softmax.
63
- # _safe_softmax introduces additional check-select ops to guard extreme
64
- # input values to softmax, which could make the converted model inefficient
65
- # on-device.
66
- if hasattr(torch.ops.aten, "_safe_softmax"):
67
- decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
68
-
69
- return decompositions