ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/hlfb/mark_pattern/pattern.py
@@ -0,0 +1,260 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import copy
+ import dataclasses
+ from typing import Any, Callable, Optional, Union
+
+ import torch
+ from torch.export.graph_signature import TensorArgument
+ from torch.fx import Graph
+ from torch.fx import GraphModule
+ from torch.fx.passes.utils.matcher_utils import InternalMatch
+ from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+
+ from ai_edge_torch.hlfb.mark_pattern import passes
+
+
+ def _are_equal(x: Any, y: Any) -> bool:
+   if type(x) != type(y):
+     return False
+   if type(x) in [int, str]:
+     return x == y
+   if isinstance(x, float):
+     rel_tol = 1e-07
+     abs_tol = 0.0
+     return abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol)
+   if isinstance(x, list):
+     if len(x) != len(y):
+       return False
+     return all([_are_equal(a, b) for a, b in zip(x, y)])
+
+   raise Exception(f"Cannot compare type: {type(x)}")
+
+
+ @dataclasses.dataclass
+ class ScalarAttrTracker:
+   """ScalarAttrTracker is used to track the occurrence of a pattern's
+   scalar arg/attr in the pattern decomposed graph. Since a scalar attr
+   to the pattern can be transformed and turned into a/some ops' scalar
+   arg in the decomposed graph, it would be hard to programmatically get
+   the attr value from the pattern match. With the tracker and tracking info,
+   we could target the position of the decomposed op's scalar arg derived
+   from the pattern arg/attr and retrieve the value from the InternalMatch.
+
+   Args:
+     attr_name (str): name of the attr to track.
+     pattern_arg_pos (int): the index of the attr to track in the pattern's
+       export_args.
+     transform (Callable): the transform function used when targeting the
+       occurrence of the attr value in the decomposed graph. An attr value
+       may be transformed during the decomposition and appear as a derived
+       value.
+     inverse_transform (Callable): the inverse transform function that maps
+       the transformed value back to the original attr value.
+   """
+
+   attr_name: str
+   pattern_arg_pos: int
+   transform: Callable = lambda x: x
+   inverse_transform: Callable = lambda x: x
+   _source_targets: list[tuple[Any, Any]] = dataclasses.field(default_factory=list)
+
+   def track(self, *sources):
+     """Register magic values to track the (transformed) attr values in
+     the pattern decomposed graph.
+     """
+     for source in sources:
+       target = self.transform(source)
+       if not _are_equal(self.inverse_transform(target), source):
+         raise Exception(f"Invalid transform/inverse_transform for {self.attr_name}")
+       self._source_targets.append([source, target])
+     return self
+
+
+ @dataclasses.dataclass
+ class ScalarAttrLocation:
+   attr_name: str
+   node_name: str
+   pos: Union[int, str]
+   _tracker: ScalarAttrTracker
+
+   @property
+   def index(self):
+     return self.pos if isinstance(self.pos, int) else None
+
+   @property
+   def key(self):
+     return self.pos if isinstance(self.pos, str) else None
+
+
+ def _find_scalar_attr(
+     pattern_module: torch.nn.Module, export_args: tuple[Any], tracker: ScalarAttrTracker
+ ) -> ScalarAttrLocation:
+   scalar_loc_intersections = None
+   for source, target in tracker._source_targets:
+     track_args = list(export_args)
+     track_args[tracker.pattern_arg_pos] = source
+     ep = torch.export.export(pattern_module, tuple(track_args))
+
+     scalar_locs = set()
+     nodes = ep.graph_module.graph.nodes
+     for n in nodes:
+       for arg_pos, arg in enumerate(n.args):
+         if type(arg) == type(target) and arg == target:
+           scalar_locs.add((n.name, arg_pos))
+       for attr, val in n.kwargs.items():
+         if type(val) == type(target) and val == target:
+           scalar_locs.add((n.name, attr))
+
+     if scalar_loc_intersections is None:
+       scalar_loc_intersections = scalar_locs
+     else:
+       scalar_loc_intersections = scalar_loc_intersections & scalar_locs
+
+     if not scalar_loc_intersections:
+       break
+
+   if not scalar_loc_intersections:
+     return None
+   # Choose any occurrence as the attr provider
+   node_name, pos = scalar_loc_intersections.pop()
+   return ScalarAttrLocation(tracker.attr_name, node_name, pos, _tracker=tracker)
+
+
+ class Pattern:
+
+   def __init__(
+       self,
+       name: str,
+       module: Union[Callable, torch.nn.Module],
+       export_args: tuple[Any],
+       *,
+       attr_builder: Callable[
+           ["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
+       ] = None,
+       scalar_attr_trackers: list[ScalarAttrTracker] = None,
+   ):
+     """The PyTorch computation pattern to match against a model.
+
+     Args:
+       name (str): the name of the pattern. It would be propagated to
+         the `name` attr in StableHLO composite ops for the matched
+         model subgraphs in the lowering.
+       module (torch.nn.Module or Callable): the PyTorch computation.
+       export_args (tuple[Any]): the args used to export the pattern module
+         with torch.export.export. If export_args contains non-tensor
+         Python scalars, there must be a corresponding attr tracker
+         in `scalar_attr_trackers` for each scalar arg.
+       attr_builder (Callable[[Pattern, GraphModule, InternalMatch], Optional[dict[str, Any]]]):
+         the callable that produces a scalar attrs dict, which would be
+         propagated to `attr` in StableHLO composite ops for the matched
+         model subgraphs in the lowering.
+       scalar_attr_trackers (list[ScalarAttrTracker]): the trackers
+         for scalar args in `export_args`, which are used to track
+         the attr occurrence(s) and retrieve their values from the
+         matched subgraph.
+     """
+     if not isinstance(module, torch.nn.Module):
+
+       class PatternModule(torch.nn.Module):
+
+         def __init__(self, func):
+           super().__init__()
+           self.func = func
+
+         def forward(self, *args, **kwargs):
+           return self.func(*args, **kwargs)
+
+       module = PatternModule(module).eval()
+
+     self.name = name
+     self.exported_program = torch.export.export(module, export_args)
+     self.graph_module = self.exported_program.graph_module
+     self.attr_builder = attr_builder
+     self._scalar_attr_trackers = scalar_attr_trackers if scalar_attr_trackers else []
+
+     # Sanitize graph_module for more precise pattern matching.
+     # The graph_module to match against this pattern should apply equivalent
+     # sanitization.
+     self.graph_module = passes.remove_clone_ops(self.graph_module)
+     self.graph_module = passes.remove_dangling_args(self.graph_module)
+
+     self._scalar_attr_locations = []
+     for tracker in self._scalar_attr_trackers:
+       self._scalar_attr_locations.append(
+           _find_scalar_attr(module, export_args, tracker)
+       )
+
+     # Builds list of ordered input and output nodes.
+     self.graph_nodes_map = {}
+     for node in self.graph_module.graph.nodes:
+       self.graph_nodes_map[node.name] = node
+
+     self.input_nodes = tuple(
+         self.graph_nodes_map[spec.arg.name]
+         for spec in self.exported_program.graph_signature.input_specs
+         if isinstance(spec.arg, TensorArgument)
+     )
+     self.output_nodes = tuple(
+         self.graph_nodes_map[spec.arg.name]
+         for spec in self.exported_program.graph_signature.output_specs
+     )
+
+   def register_attr_builder(self, attr_builder):
+     self.attr_builder = attr_builder
+     return attr_builder
+
+   def match(
+       self,
+       graph_module: GraphModule,
+   ) -> list[tuple[InternalMatch, dict[str, Any]]]:
+     matcher = SubgraphMatcher(
+         self.graph_module.graph,
+         match_output=False,
+         match_placeholder=False,
+         remove_overlapping_matches=True,
+         ignore_literals=True,
+     )
+     matches = matcher.match(graph_module.graph)
+
+     match_with_attrs = []
+     # Graph traversal must be done in the reverse order (from SubgraphMatcher).
+     for match in matches[::-1]:
+       if self.attr_builder is not None:
+         attrs = self.attr_builder(self, graph_module, match)
+       else:
+         attrs = {}
+
+       for loc in self._scalar_attr_locations:
+         attrs[loc.attr_name] = self._get_attr_value_from_pattern_match(match, loc)
+
+       attrs = attrs if attrs else None
+       match_with_attrs.append((match, attrs))
+     return match_with_attrs
+
+   def _get_attr_value_from_pattern_match(
+       self,
+       match: InternalMatch,
+       loc: ScalarAttrLocation,
+   ):
+     matched_val = None
+     for k, v in match.nodes_map.items():
+       if k.name == loc.node_name:
+         if loc.index:
+           matched_val = v.args[loc.index]
+         elif loc.key in v.kwargs.keys():
+           matched_val = v.kwargs[loc.key]
+     attr_val = loc._tracker.inverse_transform(matched_val)
+     return attr_val
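
For orientation, the pattern-matching API above is exercised by ai_edge_torch/hlfb/test/test_mark_pattern.py later in this diff. A minimal usage sketch follows; it assumes the `mark_pattern.mark_pattern` helper exported from ai_edge_torch/hlfb/mark_pattern/__init__.py (shown in the file list above), and `MyModel` is a placeholder for any module whose graph contains the targeted op. The pattern name is arbitrary.

    import torch
    from ai_edge_torch.hlfb import mark_pattern

    # Pattern to recognize: log_softmax with a scalar `dim` argument.
    # The ScalarAttrTracker records how `dim` appears in the decomposed graph
    # so its value can be recovered and attached as a composite attribute.
    pattern = mark_pattern.Pattern(
        "test.log_softmax",
        lambda x, dim: torch.nn.functional.log_softmax(x, dim=dim),
        export_args=(torch.rand(10, 10, 10), 1),
        scalar_attr_trackers=[
            mark_pattern.ScalarAttrTracker("dim", pattern_arg_pos=1).track(0).track(1),
        ],
    )

    # Export the model, then mark every match of the pattern in its graph so the
    # lowering emits stablehlo.composite ops named "test.log_softmax".
    exported_program = torch.export.export(MyModel().eval(), (torch.rand(10, 10),))
    mark_pattern.mark_pattern(exported_program.graph_module, pattern)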
ai_edge_torch/hlfb/test/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/hlfb/test/test_mark_pattern.py
@@ -0,0 +1,133 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+
+ import torch
+ import torch_xla
+
+ from ai_edge_torch.hlfb import mark_pattern
+
+
+ def _export_stablehlo_mlir(model, args=None):
+   if not isinstance(model, torch.export.ExportedProgram):
+     ep = torch.export.export(model, args)
+   else:
+     ep = model
+   stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
+   return stablehlo_gm.get_stablehlo_text()
+
+
+ class TestMarkPattern(unittest.TestCase):
+
+   def test_mark_pattern(self):
+
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x):
+         return x * x + x + x
+
+     pattern = mark_pattern.Pattern(
+         "test.add",
+         lambda a, b: a + b,
+         export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+     )
+
+     model = TestModel().eval()
+     args = (torch.rand(20, 20),)
+     exported_program = torch.export.export(model, args)
+     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+     mlir = _export_stablehlo_mlir(exported_program)
+
+     self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
+
+   def test_mark_pattern_with_attr_builder(self):
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x):
+         return x * x * x + x - x * x + x
+
+     pattern = mark_pattern.Pattern(
+         "test.add",
+         lambda a, b: a + b,
+         export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+         attr_builder=lambda *args: {"alias": "test.test_add"},
+     )
+
+     model = TestModel().eval()
+     args = (torch.rand(20, 20),)
+     exported_program = torch.export.export(model, args)
+     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+     mlir = _export_stablehlo_mlir(exported_program)
+
+     self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
+     self.assertEqual(mlir.count('composite_attributes = {alias = "test.test_add"}'), 2)
+
+   def test_mark_pattern_with_scalar_attr_tracker(self):
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x):
+         r = x
+         for idx in range(5):
+           r = torch.nn.LogSoftmax(dim=idx % 2)(r) * x
+         return r
+
+     pattern = mark_pattern.Pattern(
+         "test.log_softmax",
+         lambda x, dim: torch.nn.functional.log_softmax(x, dim=dim),
+         export_args=(torch.rand(10, 10, 10), 1),
+         scalar_attr_trackers=[
+             mark_pattern.ScalarAttrTracker("dim", pattern_arg_pos=1)
+             .track(0)
+             .track(1)
+             .track(2),
+         ],
+     )
+
+     model = TestModel().eval()
+     args = (torch.rand(10, 10),)
+     exported_program = torch.export.export(model, args)
+     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+     mlir = _export_stablehlo_mlir(exported_program)
+
+     self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 5)
+     self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 3)
+     self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 2)
+
+   def test_mark_tangent_model_and_pattern_input(self):
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x, y):
+         z = torch.ops.aten.relu(x)
+         z = z + y
+         return z
+
+     pattern = mark_pattern.Pattern(
+         "test.relu",
+         lambda x: torch.ops.aten.relu(x),
+         export_args=(torch.rand(2, 2),),
+     )
+
+     model = TestModel().eval()
+     args = (torch.rand(20, 20), torch.rand(20, 20))
+     exported_program = torch.export.export(model, args)
+     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+     mlir = _export_stablehlo_mlir(exported_program)
+
+     self.assertEqual(mlir.count('stablehlo.composite "test.relu'), 1)
+
+
+ if __name__ == "__main__":
+   unittest.main()
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py
@@ -0,0 +1,270 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import math
+ import unittest
+
+ import torch
+ import torch.nn.functional as F
+ import torch_xla
+
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+
+
+ def _export_stablehlo_mlir(model, args):
+   ep = torch.export.export(model, args)
+   stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
+   return stablehlo_gm.get_stablehlo_text()
+
+
+ class TestStableHLOCompositeBuilder(unittest.TestCase):
+
+   def test_build_composite(self):
+     class SampleModel(torch.nn.Module):
+
+       def forward(self, x):
+         builder = StableHLOCompositeBuilder(name="test.plus_two")
+         y = x + 1
+         y = builder.mark_inputs(y)
+         z = y + 2
+         z = builder.mark_outputs(z)
+         return z
+
+     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
+     self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1)
+
+   def test_build_multiple_composites(self):
+     class SampleModel(torch.nn.Module):
+
+       def plus_one(self, x: torch.Tensor):
+         builder = StableHLOCompositeBuilder("test.plus_one")
+         x = builder.mark_inputs(x)
+         y = x + 1
+         y = builder.mark_outputs(y)
+         return y
+
+       def plus_two(self, x: torch.Tensor):
+         builder = StableHLOCompositeBuilder("test.plus_two")
+         x = builder.mark_inputs(x)
+         y = x + 2
+         y = builder.mark_outputs(y)
+         return y
+
+       def forward(self, x):
+         x = self.plus_two(x)
+         x = x + 3
+         x = self.plus_one(x)
+         x = x + 4
+         x = self.plus_two(x)
+         return x
+
+     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
+     self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1)
+     self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2)
+
+   def test_build_composite_with_attr(self):
+     class SampleModel(torch.nn.Module):
+
+       def __init__(self):
+         super().__init__()
+
+       def log_softmax(self, x: torch.Tensor, dim: int):
+         builder = StableHLOCompositeBuilder(name="test.log_softmax", attr={"dim": dim})
+         x = builder.mark_inputs(x)
+         y = torch.nn.functional.log_softmax(x, dim=dim)
+         y = builder.mark_outputs(y)
+         return y
+
+       def forward(self, x):
+         x = x + 1
+         x = self.log_softmax(x, 0)
+         x = self.log_softmax(x, 1)
+         return x
+
+     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
+     self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2)
+     self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1)
+     self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1)
+
+   def test_build_composite_with_mix_type_attrs(self):
+     class SampleModel(torch.nn.Module):
+
+       def __init__(self):
+         super().__init__()
+
+       def log_softmax(self, x: torch.Tensor, dim: int):
+         builder = StableHLOCompositeBuilder(
+             name="test.log_softmax",
+             attr={
+                 "dim": dim,
+                 "source": "torch.nn",
+                 "version": 1.0,
+             },
+         )
+         x = builder.mark_inputs(x)
+         y = torch.nn.functional.log_softmax(x, dim=dim)
+         y = builder.mark_outputs(y)
+         return y
+
+       def forward(self, x):
+         x = x + 1
+         x = self.log_softmax(x, 0)
+         return x
+
+     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
+     self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
+     self.assertEqual(
+         mlir.count(
+             'composite_attributes = {dim = 0 : i64, source = "torch.nn", version = 1.000000e+00 : f32}'
+         ),
+         1,
+     )
+
+   def test_sdpa_composite(self):
+     class SDPAModel(torch.nn.Module):
+
+       def scaled_dot_product_attention(
+           self,
+           q: torch.Tensor,
+           k: torch.Tensor,
+           v: torch.Tensor,
+           head_size: int,
+           mask: torch.Tensor,
+       ):
+         builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention")
+         q, k, v, mask = builder.mark_inputs(q, k, v, mask)
+
+         scale = 1.0 / math.sqrt(head_size)
+
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+         y = F.scaled_dot_product_attention(
+             q,
+             k,
+             v,
+             attn_mask=mask,
+             dropout_p=0.0,
+             is_causal=mask is None,
+             scale=scale,
+         )
+         result = y.transpose(1, 2)
+         result = builder.mark_outputs(result)
+         return result
+
+       def forward(self, q, k, v, mask):
+         x = self.scaled_dot_product_attention(
+             q,
+             k,
+             v,
+             8,
+             mask,
+         )
+         return x
+
+     query = torch.rand(1, 1, 32, 4)
+     key = torch.rand(1, 500, 1, 4)
+     value = torch.rand(1, 500, 1, 4)
+     mask = torch.rand(1, 1, 1, 500)
+
+     mlir = _export_stablehlo_mlir(
+         SDPAModel().eval(),
+         (query, key, value, mask),
+     )
+     self.assertEqual(
+         mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1
+     )
+
+   def test_sdpa_composite_with_attr(self):
+     class SDPAModel(torch.nn.Module):
+
+       def scaled_dot_product_attention(
+           self,
+           q: torch.Tensor,
+           k: torch.Tensor,
+           v: torch.Tensor,
+           head_size: int,
+           include_captanh: bool,
+       ):
+         builder = StableHLOCompositeBuilder(
+             name="test.scaled_dot_product_attention",
+             attr={"include_captanh": include_captanh},
+         )
+         q, k, v = builder.mark_inputs(q, k, v)
+
+         scale = 1.0 / math.sqrt(head_size)
+
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+         y = F.scaled_dot_product_attention(
+             q,
+             k,
+             v,
+             attn_mask=None,
+             dropout_p=0.0,
+             is_causal=True,
+             scale=scale,
+         )
+         result = y.transpose(1, 2)
+         result = builder.mark_outputs(result)
+         return result
+
+       def forward(self, q, k, v):
+         x = self.scaled_dot_product_attention(q, k, v, 8, True)
+         y = self.scaled_dot_product_attention(q, k, v, 8, False)
+         return x + y
+
+     query = torch.rand(1, 1, 32, 4)
+     key = torch.rand(1, 500, 1, 4)
+     value = torch.rand(1, 500, 1, 4)
+     mlir = _export_stablehlo_mlir(
+         SDPAModel().eval(),
+         (query, key, value),
+     )
+     self.assertEqual(
+         mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
+     )
+     self.assertEqual(mlir.count("composite_attributes = {include_captanh = true}"), 1)
+     self.assertEqual(mlir.count("composite_attributes = {include_captanh = false}"), 1)
+
+   def test_build_composite_with_multiple_inputs_outputs(self):
+     class SampleModel(torch.nn.Module):
+
+       def mimo_sample(self, a, b, c):
+         builder = StableHLOCompositeBuilder(name="test.mimo_sample")
+
+         a, b, c = builder.mark_inputs(a, b, c)
+         x = a + b + c
+         y = (a - b) * x
+         z = (c + 1.0) * a
+         x, y, z = builder.mark_outputs(x, y, z)
+
+         result = x + y * z
+         return result
+
+       def forward(self, a, b, c):
+         x = self.mimo_sample(a, b, c)
+         x = self.mimo_sample(a, b, x)
+         x = self.mimo_sample(x, x, c)
+         return x
+
+     mlir = _export_stablehlo_mlir(
+         SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2))
+     )
+     self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3)
+
+
+ if __name__ == "__main__":
+   unittest.main()
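
The composites marked with StableHLOCompositeBuilder in the tests above are meant to be preserved when a model is converted with this package. A hedged end-to-end sketch, assuming the top-level `ai_edge_torch.convert` entry point (ai_edge_torch/__init__.py and ai_edge_torch/convert/converter.py in the file list) and the edge model's `export` method from ai_edge_torch/model.py, with a hypothetical output path:

    import torch
    import ai_edge_torch
    from ai_edge_torch.hlfb import StableHLOCompositeBuilder


    class PlusTwo(torch.nn.Module):
      # Wrap the region between mark_inputs/mark_outputs in a
      # "test.plus_two" StableHLO composite, as in the tests above.
      def forward(self, x):
        builder = StableHLOCompositeBuilder(name="test.plus_two")
        x = builder.mark_inputs(x)
        y = x + 2
        y = builder.mark_outputs(y)
        return y


    sample_args = (torch.rand(2, 2),)
    edge_model = ai_edge_torch.convert(PlusTwo().eval(), sample_args)
    edge_model.export("plus_two.tflite")  # hypothetical output path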