ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Sequence, Union
17
+
18
+ from torch.export import ExportedProgram
19
+ from torch.fx.passes.infra.pass_manager import pass_result_wrapper
20
+ import torch.utils._pytree as pytree
21
+
22
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
23
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
24
+ from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
25
+ from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
26
+ from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
27
+ from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass # NOQA
28
+ from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
29
+ from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
30
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
31
+
32
+
33
# TODO(cnchan): make a PassManager class.
def run_passes(
    exported_program: ExportedProgram,
    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
) -> ExportedProgram:
    """Apply every pass in `passes`, in order, and return the final program.

    `passes` may be an arbitrarily nested structure; it is flattened with
    `pytree.tree_flatten` before iteration.

    * An `ExportedProgramPassBase` is called on the whole ExportedProgram and
      the `exported_program` field of its result becomes the current program.
    * Any other pass is wrapped with `pass_result_wrapper` and called on the
      current program's `graph_module`. Only when it reports `modified` AND
      hands back a GraphModule object different from the current one is a new
      ExportedProgram rebuilt around the returned module; every other field is
      carried over from the current program. A pass that edits the module in
      place therefore leaves the program object unchanged.

    Args:
      exported_program: the program to transform.
      passes: the (possibly nested) collection of passes to apply.

    Returns:
      The transformed ExportedProgram.
    """
    flat_passes, _ = pytree.tree_flatten(passes)
    for candidate in flat_passes:
        if isinstance(candidate, ExportedProgramPassBase):
            exported_program = candidate(exported_program).exported_program
            continue

        fx_pass = pass_result_wrapper(candidate)
        current_gm = exported_program.graph_module
        new_gm, modified = fx_pass(current_gm)
        if not modified or new_gm is current_gm:
            continue

        exported_program = ExportedProgram(
            root=new_gm,
            graph=new_gm.graph,
            graph_signature=exported_program.graph_signature,
            state_dict=exported_program.state_dict,
            range_constraints=exported_program.range_constraints,
            module_call_graph=exported_program.module_call_graph,
            example_inputs=exported_program.example_inputs,
            verifier=exported_program.verifier,
            constants=exported_program.constants,
        )
    return exported_program
@@ -0,0 +1,49 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import abc
17
+ from collections import namedtuple
18
+
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+ from torch.fx.passes.infra.pass_base import PassBase as FxPassBase
22
+ from torch.fx.passes.infra.pass_base import PassResult as FxPassResult
23
+
24
+
25
class ExportedProgramPassResult(
    namedtuple("ExportedProgramPassResult", ("exported_program", "modified"))
):
    """Result of running an ExportedProgramPassBase.

    Fields:
      exported_program: the ExportedProgram produced by the pass.
      modified: whether the pass reports having changed the program.
    """

    def __new__(cls, exported_program, modified):
        instance = super().__new__(cls, exported_program, modified)
        return instance
31
+
32
+
33
class ExportedProgramPassBase(abc.ABC):
    """Abstract base for passes that transform a whole ExportedProgram.

    Calling an instance runs `requires`, then `call`, then `ensures`, and
    returns whatever `call` produced. Subclasses must implement `call`;
    `requires` and `ensures` are optional hooks that do nothing by default.
    """

    def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
        """Run the pass between its pre/post hooks and return `call`'s result.

        Note that `ensures` receives the same `exported_program` that was
        passed in, not the program contained in the returned result.
        """
        self.requires(exported_program)
        outcome = self.call(exported_program)
        self.ensures(exported_program)
        return outcome

    @abc.abstractmethod
    def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
        """Transform `exported_program`; must be implemented by subclasses."""

    def requires(self, exported_program: ExportedProgram) -> None:
        """Precondition hook invoked before `call`; no-op by default."""

    def ensures(self, exported_program: ExportedProgram) -> None:
        """Postcondition hook invoked after `call`; no-op by default."""
@@ -0,0 +1,192 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import copy
17
+ import functools
18
+ from typing import Any, Callable
19
+
20
+ import torch
21
+ from torch.fx import GraphModule
22
+ from torch.fx import Node
23
+ from torch.fx.passes.infra.pass_base import PassBase
24
+ from torch.fx.passes.infra.pass_base import PassResult
25
+ import torch.utils._pytree as pytree
26
+
27
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
28
+
29
# Registry mapping an ATen op overload to the function that rewrites a matching
# FX node (looked up by `node.target`). Populated via
# `_register_composite_builder`.
_composite_builders: dict[Callable, Callable[[GraphModule, Node], None]] = {}


def _register_composite_builder(op):
    """Return a decorator that registers its target as the builder for `op`.

    Args:
      op: an ATen `OpOverload` (e.g. `torch.ops.aten.avg_pool2d.default`) or
        an `OpOverloadPacket` (e.g. `torch.ops.aten.avg_pool2d`). For a
        packet, the builder is registered under every overload the packet
        reports via `overloads()`.

    Returns:
      A decorator that records the decorated function in
      `_composite_builders` and returns it unchanged.
    """

    def inner(func):
        if isinstance(op, torch._ops.OpOverloadPacket):
            # Register each concrete overload: registry lookups use the exact
            # node target, which is an individual overload rather than the
            # packet.
            for overload_name in op.overloads():
                _composite_builders[getattr(op, overload_name)] = func
        else:
            _composite_builders[op] = func
        return func

    return inner
42
+
43
+
44
def _tree_map_to_composite_attr_values(values, *, stringify_incompatible_values=True):
    """Convert a (possibly nested) structure of values into composite attrs.

    Each leaf visited by `pytree.tree_map` is mapped as follows:
      * `None` becomes the string "py_None".
      * `str`, `int`, `float` and `bool` values pass through unchanged.
      * Any other value is converted with `str()` when
        `stringify_incompatible_values` is true, and left as-is otherwise.

    The container structure of `values` is preserved.
    """

    def to_attr(leaf):
        if leaf is None:
            return "py_None"
        is_native = isinstance(leaf, (str, int, float, bool))
        if not is_native and stringify_incompatible_values:
            return str(leaf)
        return leaf

    return pytree.tree_map(to_attr, values)
58
+
59
+
60
class TorchOpArgumentsMapper:
    """Normalize an ATen op's call arguments into a single kwargs dict.

    Attributes:
      op: the resolved op (an `OpOverloadPacket` is replaced by its
        `.default` overload).
      arg_specs: `(name, default_value)` pairs, in schema order, taken from
        `op._schema.arguments`.
    """

    def __init__(self, op):
        resolved = op.default if isinstance(op, torch._ops.OpOverloadPacket) else op
        assert hasattr(resolved, "_schema")
        self.op = resolved
        self.arg_specs = [
            (argument.name, argument.default_value)
            for argument in resolved._schema.arguments
        ]

    def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
        """Inspect the op's schema and extract all its args and kwargs
        into one single kwargs dict, with default values for those
        unspecified args and kwargs.

        Positional `args` are matched to schema names in order and overwrite
        same-named entries from `kwargs`. Every remaining schema argument not
        already present is filled with its schema default.
        """
        merged = dict(kwargs) if kwargs else {}
        schema_names = [name for name, _ in self.arg_specs]
        merged.update(zip(schema_names, args))
        for name, default in self.arg_specs[len(args):]:
            merged.setdefault(name, default)
        return merged
85
+
86
+
87
@_register_composite_builder(torch.ops.aten.hardswish.default)
def _aten_hardswish(gm: GraphModule, node: Node):
    """Retarget an `aten.hardswish.default` node to emit a StableHLO composite.

    The node's original target is captured and replaced by a wrapper that
    marks its single tensor input and its output with a composite named
    "aten.hardswish.default" around a call to that original target.
    `gm` is not used.
    """
    original_op = node.target

    def hardswish(self: torch.Tensor):
        builder = StableHLOCompositeBuilder("aten.hardswish.default")
        marked_self = builder.mark_inputs(self)
        return builder.mark_outputs(original_op(marked_self))

    node.target = hardswish
100
+
101
+
102
@_register_composite_builder(torch.ops.aten.avg_pool2d.default)
def _aten_avg_pool2d(gm: GraphModule, node: Node):
    """Retarget an `aten.avg_pool2d.default` node to a composite-emitting wrapper.

    The wrapper marks the op as a StableHLO composite named
    "aten.avg_pool2d.default" only for argument combinations listed below as
    supported; for anything else it calls the original op unchanged.
    `gm` is not used.
    """
    op = node.target
    args_mapper = TorchOpArgumentsMapper(op)

    def avg_pool2d(*args, **kwargs):
        full_kwargs = args_mapper.get_full_kwargs(args, kwargs)

        def is_same_padding(
            input_shape: list[int],
            kernel_size: list[int],
            stride: list[int],
            padding: list[int],
        ):
            """True when `padding` equals symmetric 'SAME' padding on every dim."""

            def dim_is_same(in_size, k_size, step, pad):
                out_size = int((in_size + step - 1) / step)  # ceil(in_size / step)
                total_pad = max(0, (out_size - 1) * step + k_size - in_size)
                # Odd total padding cannot be split evenly between both sides.
                return total_pad % 2 == 0 and total_pad // 2 == pad

            return all(
                dim_is_same(*dims)
                for dims in zip(input_shape, kernel_size, stride, padding)
            )

        def is_valid_padding(padding: list[int]):
            """True when every padding entry is zero ('VALID' padding)."""
            return not any(padding)

        # Composite attributes should not carry an empty array (it would be
        # lowered to an ArrayAttr), so an omitted stride is canonicalized to
        # kernel_size, following the default behaviour.
        if not full_kwargs["stride"]:
            full_kwargs["stride"] = full_kwargs["kernel_size"]

        # Only wrap in a composite when the underlying converter can handle it.
        # TODO: remove this check once the converter can inline composites it
        # cannot handle.
        unsupported = (
            # ceil_mode=True and a set divisor_override are not covered.
            full_kwargs["ceil_mode"]
            or full_kwargs["divisor_override"] is not None
            # count_include_pad=False is only covered for VALID or SAME padding.
            or (
                not full_kwargs["count_include_pad"]
                and not is_valid_padding(full_kwargs["padding"])
                and not is_same_padding(
                    list(full_kwargs["self"].shape)[2:],
                    full_kwargs["kernel_size"],
                    full_kwargs["stride"],
                    full_kwargs["padding"],
                )
            )
        )
        if unsupported:
            return op(*args, **kwargs)

        attr_names = (
            "kernel_size",
            "stride",
            "padding",
            "ceil_mode",
            "count_include_pad",
            "divisor_override",
        )
        builder = StableHLOCompositeBuilder(
            "aten.avg_pool2d.default",
            attr=_tree_map_to_composite_attr_values(
                {name: full_kwargs[name] for name in attr_names}
            ),
        )
        full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
        return builder.mark_outputs(op(**full_kwargs))

    node.target = avg_pool2d
181
+
182
+
183
class BuildAtenCompositePass(PassBase):
    """FX pass that hands registered ATen nodes to their composite builders.

    Every node whose target has an entry in `_composite_builders` is passed,
    together with the graph module, to that builder. The graph is then linted
    and recompiled, and the pass always reports the module as modified.
    """

    def call(self, graph_module: GraphModule):
        for node in graph_module.graph.nodes:
            build = _composite_builders.get(node.target)
            if build is not None:
                build(graph_module, node)

        graph_module.graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import functools
17
+
18
+ import torch
19
+
20
+ from ai_edge_torch.convert.fx_passes import FxPassBase
21
+ from ai_edge_torch.convert.fx_passes import FxPassResult
22
+ from ai_edge_torch.hlfb import mark_pattern
23
+
24
+
25
@functools.cache
def _get_upsample_bilinear2d_pattern():
    """Build (once) the 2x bilinear upsampling pattern with align_corners=False.

    The pattern is created from `interpolate(x, scale_factor=2,
    mode="bilinear", align_corners=False)` with a random (1, 3, 100, 100)
    example input and is named "odml.upsample_bilinear2d". Its attributes are
    the matched output node's last two dims as "output" and
    `align_corners=False`.
    """
    example_input = torch.rand(1, 3, 100, 100)
    pattern = mark_pattern.Pattern(
        "odml.upsample_bilinear2d",
        lambda x: torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=False
        ),
        export_args=(example_input,),
    )

    @pattern.register_attr_builder
    def attr_builder(pattern, graph_module, internal_match):
        result_node = internal_match.returning_nodes[0]
        height, width = result_node.meta["val"].shape[-2:]
        return {
            "output": (int(height), int(width)),
            "align_corners": False,
        }

    return pattern
45
+
46
+
47
@functools.cache
def _get_upsample_bilinear2d_align_corners_pattern():
    """Build (once) the 2x bilinear upsampling pattern with align_corners=True.

    The pattern is created from `interpolate(x, scale_factor=2,
    mode="bilinear", align_corners=True)` with a random (1, 3, 100, 100)
    example input and is named "odml.upsample_bilinear2d". Its attributes are
    the matched output node's last two dims as "output" and
    `align_corners=True`.
    """
    pattern = mark_pattern.Pattern(
        "odml.upsample_bilinear2d",
        lambda x: torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True
        ),
        export_args=(torch.rand(1, 3, 100, 100),),
    )

    # Parameter order matches the align_corners=False attr builder
    # (pattern, graph_module, internal_match); only internal_match is used.
    @pattern.register_attr_builder
    def attr_builder(pattern, graph_module, internal_match):
        output = internal_match.returning_nodes[0]
        output_h, output_w = output.meta["val"].shape[-2:]
        return {
            "output": (int(output_h), int(output_w)),
            "align_corners": True,
        }

    return pattern
67
+
68
+
69
class BuildUpsampleBilinear2DCompositePass(FxPassBase):
    """FX pass that marks 2x bilinear upsampling subgraphs as composites.

    The align_corners=False pattern and then the align_corners=True pattern
    are applied via `mark_pattern.mark_pattern`. The graph is then linted and
    recompiled, and the pass always reports the module as modified.
    """

    def __init__(self):
        super().__init__()
        pattern_factories = (
            _get_upsample_bilinear2d_pattern,
            _get_upsample_bilinear2d_align_corners_pattern,
        )
        self._patterns = [factory() for factory in pattern_factories]

    def call(self, graph_module: torch.fx.GraphModule):
        marked = graph_module
        for pattern in self._patterns:
            marked = mark_pattern.mark_pattern(marked, pattern)

        marked.graph.lint()
        marked.recompile()
        return FxPassResult(marked, True)
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import torch
17
+ from torch.export import ExportedProgram
18
+
19
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
20
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
21
+
22
+ # A dummy decomp table for running ExportedProgram.run_decompositions without
23
+ # any op decompositions but just aot_export_module. Due to the check in
24
+ # run_decompositions, if None or an empty dict is passed as decomp_table,
25
+ # it will run the default aten-coreaten decompositions. Therefore a non-empty
26
+ # dummy decomp table is needed.
27
+ # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
28
+ _dummy_decomp_table = {
29
+ torch._ops.OperatorBase(): lambda: None,
30
+ }
31
+
32
+
33
class CanonicalizePass(ExportedProgramPassBase):
    """Re-run `run_decompositions` with a placeholder table.

    Passing `_dummy_decomp_table` keeps `run_decompositions` from applying its
    default decompositions. The pass always reports the program as modified.
    """

    def call(self, exported_program: ExportedProgram):
        canonical = exported_program.run_decompositions(_dummy_decomp_table)
        return ExportedProgramPassResult(canonical, True)
@@ -0,0 +1,73 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import torch
17
+ from torch.fx.passes.infra.pass_base import PassBase
18
+ from torch.fx.passes.infra.pass_base import PassResult
19
+ import torch.utils._pytree as pytree
20
+ import torch_xla.experimental.xla_mlir_debuginfo # Import required to register torch.ops.xla.write_mlir_debuginfo
21
+
22
+
23
def _get_mlir_debuginfo(node: torch.fx.Node):
    """Return the debug-info string to attach to the outputs of `node`.

    The string is built from `node.meta["nn_module_stack"]`, treated as empty
    when absent. Each `(name, layer)` value contributes the layer's type name:
    `layer` itself if it is a string, otherwise its module-qualified class
    name, or the bare qualname for builtins. When `name` is non-empty, "_"
    plus the last dotted component of `name` is appended. Entries are joined
    with "/" and the result ends with ";". For example,
    ("layers.0", torch.nn.Linear) yields "torch.nn.modules.linear.Linear_0".
    """

    def qualified_name(cls):
        # Builtin types keep their bare qualname; others are module-qualified.
        if cls.__module__ == "builtins":
            return cls.__qualname__
        return cls.__module__ + "." + cls.__qualname__

    def describe(name, layer):
        type_name = layer if isinstance(layer, str) else qualified_name(layer)
        suffix = ("_" + name.split(".")[-1]) if name else ""
        return type_name + suffix

    module_stack = node.meta.get("nn_module_stack", {})
    # TODO(yijieyang): Encode aten op and attrs.
    hierarchy = "/".join(describe(name, layer) for name, layer in module_stack.values())
    return hierarchy + ";"
43
+
44
+
45
def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.Node):
    """Retarget a `call_function` node so its tensor outputs carry debug info.

    Nodes whose op is not "call_function" are left untouched. For a
    `call_function` node, the debug-info string is computed once, up front,
    with `_get_mlir_debuginfo`. `node.target` is then replaced with a wrapper
    that calls the original target and passes every `torch.Tensor` in its
    (possibly nested) output through `torch.ops.xla.write_mlir_debuginfo`
    with that string.

    Args:
      node: the FX node to rewrite in place. Previously mis-annotated as a
        GraphModule; the function reads and writes node attributes
        (`op`, `target`, `meta`).
    """
    if node.op != "call_function":
        return

    target = node.target
    debuginfo = _get_mlir_debuginfo(node)

    def debuginfo_writer(*args, **kwargs):
        outputs = target(*args, **kwargs)
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: torch.ops.xla.write_mlir_debuginfo(x, debuginfo),
            outputs,
        )

    node.target = debuginfo_writer
63
+
64
+
65
class InjectMlirDebuginfoPass(PassBase):
    """FX pass that attaches MLIR debug info to call_function outputs.

    Every node in the graph is offered to
    `_wrap_call_function_node_with_debuginfo_writer`, which only retargets
    `call_function` nodes. The graph is then linted and recompiled, and the
    pass always reports the module as modified.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        for node in graph.nodes:
            _wrap_call_function_node_with_debuginfo_writer(node)

        graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)
@@ -0,0 +1,16 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA