ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import copy
|
|
16
|
+
import dataclasses
|
|
17
|
+
from typing import Any, Callable, Optional, Union
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from torch.export.graph_signature import TensorArgument
|
|
21
|
+
from torch.fx import Graph
|
|
22
|
+
from torch.fx import GraphModule
|
|
23
|
+
from torch.fx.passes.utils.matcher_utils import InternalMatch
|
|
24
|
+
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
|
|
25
|
+
|
|
26
|
+
from ai_edge_torch.hlfb.mark_pattern import passes
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _are_equal(x: Any, y: Any) -> bool:
|
|
30
|
+
if type(x) != type(y):
|
|
31
|
+
return False
|
|
32
|
+
if type(x) in [int, str]:
|
|
33
|
+
return x == y
|
|
34
|
+
if isinstance(x, float):
|
|
35
|
+
rel_tol = 1e-07
|
|
36
|
+
abs_tol = 0.0
|
|
37
|
+
return abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol)
|
|
38
|
+
if isinstance(x, list):
|
|
39
|
+
if len(x) != len(y):
|
|
40
|
+
return False
|
|
41
|
+
return all([_are_equal(a, b) for a, b in zip(x, y)])
|
|
42
|
+
|
|
43
|
+
raise Exception(f"Cannot compare type: {type(x)}")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclasses.dataclass
class ScalarAttrTracker:
  """Tracks where a pattern's scalar arg/attr occurs in the decomposed graph.

  A scalar attr passed to the pattern can be transformed during export and
  decomposition and end up as (part of) some op's scalar arg in the decomposed
  graph, so its value cannot be read back from a pattern match directly. The
  pattern is exported with each registered source value; the transformed
  targets are then searched for in the decomposed graph. This locates the op
  arg derived from the attr, so its value can be retrieved from an
  InternalMatch.

  Attributes:
    attr_name (str): name of the attr to track.
    pattern_arg_pos (int): the index of the attr to track in the pattern's
      export_args.
    transform (Callable): maps a source attr value to the (possibly derived)
      value expected to appear in the decomposed graph. Defaults to identity.
    inverse_transform (Callable): maps a value found in the decomposed graph
      back to the original attr value. Defaults to identity.
  """

  attr_name: str
  pattern_arg_pos: int
  transform: Callable = lambda x: x
  inverse_transform: Callable = lambda x: x
  # (source, transformed target) pairs registered through `track`.
  _source_targets: list[tuple[Any, Any]] = dataclasses.field(
      default_factory=list
  )

  def track(self, *sources) -> "ScalarAttrTracker":
    """Registers magic source values used to locate the attr.

    Args:
      *sources: attr values to export the pattern with. Each one is passed
        through `transform`, and the result must map back to the source via
        `inverse_transform`.

    Returns:
      self, so calls can be chained.

    Raises:
      ValueError: if transform/inverse_transform do not round-trip a source.
    """
    for source in sources:
      target = self.transform(source)
      if not _are_equal(self.inverse_transform(target), source):
        raise ValueError(
            f"Invalid transform/inverse_transform for {self.attr_name}"
        )
      # Store a tuple to match the declared element type.
      self._source_targets.append((source, target))
    return self
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclasses.dataclass
class ScalarAttrLocation:
  """Position of a tracked scalar attr within one node of the pattern graph.

  Attributes:
    attr_name: name of the tracked attr.
    node_name: name of the pattern-graph node that holds the value.
    pos: a positional index (int) into the node's args, or a keyword name
      (str) into its kwargs.
    _tracker: the tracker this location was found with; its
      inverse_transform recovers the original attr value.
  """

  attr_name: str
  node_name: str
  pos: Union[int, str]
  _tracker: ScalarAttrTracker

  @property
  def index(self):
    """The positional arg index, or None when `pos` is a keyword."""
    if isinstance(self.pos, int):
      return self.pos
    return None

  @property
  def key(self):
    """The kwarg name, or None when `pos` is a positional index."""
    if isinstance(self.pos, str):
      return self.pos
    return None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _find_scalar_attr(
    pattern_module: torch.nn.Module,
    export_args: tuple[Any],
    tracker: ScalarAttrTracker,
    decomp_table=None,
) -> Optional[ScalarAttrLocation]:
  """Locates the decomposed-graph arg that carries a tracked scalar attr.

  The pattern is exported once per source value registered on `tracker`. Each
  time, that source is substituted at `tracker.pattern_arg_pos`. In each
  export, every node arg/kwarg with the same type and value as the
  transformed target is recorded. Only locations found in every export are
  kept as candidates.

  Args:
    pattern_module: the pattern module to export.
    export_args: the pattern's export args.
    tracker: the tracker holding registered (source, target) pairs.
    decomp_table: optional decomposition table run on each exported program.

  Returns:
    The location of the earliest candidate (in the first export's graph
    order), or None if the tracker has no registered sources or no location
    is common to all exports.
  """
  # A dict is used as an insertion-ordered set so that the final choice is
  # deterministic (set iteration order depends on randomized str hashing).
  candidates = None
  for source, target in tracker._source_targets:
    track_args = list(export_args)
    track_args[tracker.pattern_arg_pos] = source
    ep = torch.export.export(pattern_module, tuple(track_args))
    if decomp_table is not None:
      ep = ep.run_decompositions(decomp_table)

    scalar_locs = {}
    for n in ep.graph_module.graph.nodes:
      for arg_pos, arg in enumerate(n.args):
        if type(arg) is type(target) and arg == target:
          scalar_locs[(n.name, arg_pos)] = None
      for attr, val in n.kwargs.items():
        if type(val) is type(target) and val == target:
          scalar_locs[(n.name, attr)] = None

    if candidates is None:
      candidates = scalar_locs
    else:
      candidates = {loc: None for loc in candidates if loc in scalar_locs}

    if not candidates:
      break

  if not candidates:
    return None
  # Choose the earliest occurrence as the attr provider.
  node_name, pos = next(iter(candidates))
  return ScalarAttrLocation(tracker.attr_name, node_name, pos, _tracker=tracker)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class Pattern:
  """A PyTorch computation pattern to match against exported models."""

  def __init__(
      self,
      name: str,
      module: Union[Callable, torch.nn.Module],
      export_args: tuple[Any],
      *,
      attr_builder: Optional[
          Callable[
              ["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
          ]
      ] = None,
      scalar_attr_trackers: Optional[list[ScalarAttrTracker]] = None,
      decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
  ):
    """The PyTorch computation pattern to match against a model.

    Args:
      name (str): the name of the pattern. It is propagated to the `name`
        attr of the StableHLO composite ops for the matched model subgraphs
        in the lowering.
      module (torch.nn.Module or Callable): the PyTorch computation. A plain
        callable is wrapped in a torch.nn.Module before export.
      export_args (tuple[Any]): the args used to export the pattern module
        with torch.export.export. If export_args contains non-tensor Python
        scalars, there must be a corresponding attr tracker in
        `scalar_attr_trackers` for each scalar arg.
      attr_builder (Callable[[Pattern, GraphModule, InternalMatch],
        Optional[dict[str, Any]]]): the callable that produces a scalar
        attrs dict for each match. The dict is propagated to `attr` in the
        StableHLO composite ops for the matched model subgraphs in the
        lowering.
      scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
        args in `export_args`. They are used to track the attr occurrence(s)
        and retrieve their values from the matched subgraph.
      decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
        the decomposition table to run on the pattern's exported program.

    Raises:
      ValueError: if a scalar attr tracker cannot be located in the pattern's
        decomposed graph.
    """
    if not isinstance(module, torch.nn.Module):

      class PatternModule(torch.nn.Module):
        """Wraps a plain callable so it can be exported."""

        def __init__(self, func):
          super().__init__()
          self.func = func

        def forward(self, *args, **kwargs):
          return self.func(*args, **kwargs)

      module = PatternModule(module).eval()

    self.name = name
    self.attr_builder = attr_builder
    self._scalar_attr_trackers = (
        scalar_attr_trackers if scalar_attr_trackers else []
    )

    exported_program = torch.export.export(module, export_args)
    if decomp_table is not None:
      exported_program = exported_program.run_decompositions(decomp_table)

    self.exported_program = exported_program
    self.graph_module = self.exported_program.graph_module

    self._scalar_attr_locations = []
    for tracker in self._scalar_attr_trackers:
      loc = _find_scalar_attr(
          module, export_args, tracker, decomp_table=decomp_table
      )
      if loc is None:
        # Fail fast: a missing location would otherwise surface later as an
        # opaque AttributeError on None inside `match`.
        raise ValueError(
            f"Cannot locate scalar attr '{tracker.attr_name}' (pattern arg"
            f" {tracker.pattern_arg_pos}) in the decomposed graph of pattern"
            f" '{name}'. Register distinctive values with"
            " ScalarAttrTracker.track()."
        )
      self._scalar_attr_locations.append(loc)

    # Sanitize graph_module for more precise pattern matching.
    # The graph_module to match against this pattern should apply equivalent
    # sanitization.
    self.graph_module = passes.remove_clone_ops(self.graph_module)
    self.graph_module = passes.remove_dangling_args(self.graph_module)

    # Builds list of ordered input and output nodes.
    self.graph_nodes_map = {
        node.name: node for node in self.graph_module.graph.nodes
    }

    self.input_nodes = tuple(
        self.graph_nodes_map[spec.arg.name]
        for spec in self.exported_program.graph_signature.input_specs
        if isinstance(spec.arg, TensorArgument)
    )
    self.output_nodes = tuple(
        self.graph_nodes_map[spec.arg.name]
        for spec in self.exported_program.graph_signature.output_specs
    )

  def register_attr_builder(self, attr_builder):
    """Sets the attr builder; returns it unchanged so it works as a decorator."""
    self.attr_builder = attr_builder
    return attr_builder

  def match(
      self,
      graph_module: GraphModule,
  ) -> list[tuple[InternalMatch, Optional[dict[str, Any]]]]:
    """Finds non-overlapping occurrences of this pattern in `graph_module`.

    Args:
      graph_module: the graph module to search.

    Returns:
      A list of (match, attrs) pairs. `attrs` combines the attr_builder
      output with the tracked scalar attr values, or is None when empty.
    """
    matcher = SubgraphMatcher(
        self.graph_module.graph,
        match_output=False,
        match_placeholder=False,
        remove_overlapping_matches=True,
        ignore_literals=True,
    )
    matches = matcher.match(graph_module.graph)

    match_with_attrs = []
    # Graph traversal must be done in the reverse order (from SubgraphMatcher).
    for match in reversed(matches):
      attrs = {}
      if self.attr_builder is not None:
        # Copy into a fresh dict: scalar attrs added below must never mutate
        # a dict owned by the builder, and builders may return None.
        attrs.update(self.attr_builder(self, graph_module, match) or {})

      for loc in self._scalar_attr_locations:
        attrs[loc.attr_name] = self._get_attr_value_from_pattern_match(
            match, loc
        )

      match_with_attrs.append((match, attrs if attrs else None))
    return match_with_attrs

  def _get_attr_value_from_pattern_match(
      self,
      match: InternalMatch,
      loc: ScalarAttrLocation,
  ):
    """Reads a tracked scalar attr value from a match and inverse-transforms it.

    Args:
      match: a match of this pattern in a model graph.
      loc: where the tracked value lives among the pattern node's args/kwargs.

    Returns:
      `loc._tracker.inverse_transform` applied to the matched value (None if
      the node or kwarg is not found).
    """
    matched_val = None
    for pattern_node, model_node in match.nodes_map.items():
      if pattern_node.name != loc.node_name:
        continue
      # A positional index of 0 is valid, so test against None rather than
      # truthiness.
      if loc.index is not None:
        matched_val = model_node.args[loc.index]
      elif loc.key in model_node.kwargs:
        matched_val = model_node.kwargs[loc.key]
      break
    return loc._tracker.inverse_transform(matched_val)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch_xla
|
|
20
|
+
|
|
21
|
+
from ai_edge_torch.hlfb import mark_pattern
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _export_stablehlo_mlir(model, args=None):
  """Lowers `model` to StableHLO and returns its MLIR text.

  Args:
    model: an ExportedProgram, or a module that is exported with `args`.
    args: example args, used only when `model` still needs to be exported.
  """
  if isinstance(model, torch.export.ExportedProgram):
    program = model
  else:
    program = torch.export.export(model, args)
  shlo = torch_xla.stablehlo.exported_program_to_stablehlo(program)
  return shlo.get_stablehlo_text()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TestMarkPattern(unittest.TestCase):
  """Tests that mark_pattern marks pattern matches as StableHLO composites."""

  def test_mark_pattern(self):
    """Every addition in the model is marked as a `test.add` composite."""

    class TestModel(torch.nn.Module):

      def forward(self, x):
        return x * x + x + x

    pattern = mark_pattern.Pattern(
        "test.add",
        lambda a, b: a + b,
        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
    )

    model = TestModel().eval()
    args = (torch.rand(20, 20),)
    exported_program = torch.export.export(model, args)
    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
    mlir = _export_stablehlo_mlir(exported_program)

    # `x * x + x + x` contains two additions.
    self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)

  def test_mark_pattern_with_attr_builder(self):
    """Attrs returned by attr_builder are attached to every composite."""

    class TestModel(torch.nn.Module):

      def forward(self, x):
        return x * x * x + x - x * x + x

    pattern = mark_pattern.Pattern(
        "test.add",
        lambda a, b: a + b,
        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
        attr_builder=lambda *args: {"alias": "test.test_add"},
    )

    model = TestModel().eval()
    args = (torch.rand(20, 20),)
    exported_program = torch.export.export(model, args)
    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
    mlir = _export_stablehlo_mlir(exported_program)

    # Two additions (the subtraction must not match).
    self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
    self.assertEqual(
        mlir.count('composite_attributes = {alias = "test.test_add"}'), 2
    )

  def test_mark_pattern_with_scalar_attr_tracker(self):
    """Tracked scalar attrs are recovered per match."""

    class TestModel(torch.nn.Module):

      def forward(self, x):
        r = x
        for idx in range(5):
          r = torch.nn.LogSoftmax(dim=idx % 2)(r) * x
        return r

    pattern = mark_pattern.Pattern(
        "test.log_softmax",
        lambda x, dim: torch.nn.functional.log_softmax(x, dim=dim),
        export_args=(torch.rand(10, 10, 10), 1),
        scalar_attr_trackers=[
            mark_pattern.ScalarAttrTracker("dim", pattern_arg_pos=1)
            .track(0)
            .track(1)
            .track(2),
        ],
    )

    model = TestModel().eval()
    args = (torch.rand(10, 10),)
    exported_program = torch.export.export(model, args)
    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
    mlir = _export_stablehlo_mlir(exported_program)

    # The loop uses dims 0, 1, 0, 1, 0.
    self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 5)
    self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 3)
    self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 2)

  def test_mark_tangent_model_and_pattern_input(self):
    """A pattern whose input is also a model input is marked once."""

    class TestModel(torch.nn.Module):

      def forward(self, x, y):
        z = torch.ops.aten.relu(x)
        z = z + y
        return z

    pattern = mark_pattern.Pattern(
        "test.relu",
        lambda x: torch.ops.aten.relu(x),
        export_args=(torch.rand(2, 2),),
    )

    model = TestModel().eval()
    args = (torch.rand(20, 20), torch.rand(20, 20))
    exported_program = torch.export.export(model, args)
    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
    mlir = _export_stablehlo_mlir(exported_program)

    # Include the closing quote so only the exact composite name is counted.
    self.assertEqual(mlir.count('stablehlo.composite "test.relu"'), 1)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
if __name__ == "__main__":
  # Run this module's unittest cases when it is executed as a script.
  unittest.main()
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import math
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn.functional as F
|
|
20
|
+
import torch_xla
|
|
21
|
+
|
|
22
|
+
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _export_stablehlo_mlir(model, args):
  """Exports `model` with `args` and returns the StableHLO MLIR text."""
  program = torch.export.export(model, args)
  shlo = torch_xla.stablehlo.exported_program_to_stablehlo(program)
  return shlo.get_stablehlo_text()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TestStableHLOCompositeBuilder(unittest.TestCase):
  """Tests that StableHLOCompositeBuilder regions lower to composite ops."""

  def test_build_composite(self):
    """A single marked region lowers to exactly one composite."""

    class SampleModel(torch.nn.Module):

      def forward(self, x):
        builder = StableHLOCompositeBuilder(name="test.plus_two")
        y = x + 1
        # Only `y + 2` lies between the input and output markers.
        y = builder.mark_inputs(y)
        z = y + 2
        z = builder.mark_outputs(z)
        return z

    mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
    self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1)

  def test_build_multiple_composites(self):
    """Each builder invocation produces its own composite."""

    class SampleModel(torch.nn.Module):

      def plus_one(self, x: torch.Tensor):
        builder = StableHLOCompositeBuilder("test.plus_one")
        x = builder.mark_inputs(x)
        y = x + 1
        y = builder.mark_outputs(y)
        return y

      def plus_two(self, x: torch.Tensor):
        builder = StableHLOCompositeBuilder("test.plus_two")
        x = builder.mark_inputs(x)
        y = x + 2
        y = builder.mark_outputs(y)
        return y

      def forward(self, x):
        # plus_two is called twice and plus_one once.
        x = self.plus_two(x)
        x = x + 3
        x = self.plus_one(x)
        x = x + 4
        x = self.plus_two(x)
        return x

    mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
    self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1)
    self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2)

  def test_build_composite_with_attr(self):
    """An int attr passed to the builder is emitted per composite."""

    class SampleModel(torch.nn.Module):

      def __init__(self):
        super().__init__()

      def log_softmax(self, x: torch.Tensor, dim: int):
        builder = StableHLOCompositeBuilder(name="test.log_softmax", attr={"dim": dim})
        x = builder.mark_inputs(x)
        y = torch.nn.functional.log_softmax(x, dim=dim)
        y = builder.mark_outputs(y)
        return y

      def forward(self, x):
        x = x + 1
        x = self.log_softmax(x, 0)
        x = self.log_softmax(x, 1)
        return x

    mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
    self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2)
    self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1)
    self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1)

  def test_build_composite_with_mix_type_attrs(self):
    """int, str and float attrs are all emitted on one composite."""

    class SampleModel(torch.nn.Module):

      def __init__(self):
        super().__init__()

      def log_softmax(self, x: torch.Tensor, dim: int):
        builder = StableHLOCompositeBuilder(
            name="test.log_softmax",
            attr={
                "dim": dim,
                "source": "torch.nn",
                "version": 1.0,
            },
        )
        x = builder.mark_inputs(x)
        y = torch.nn.functional.log_softmax(x, dim=dim)
        y = builder.mark_outputs(y)
        return y

      def forward(self, x):
        x = x + 1
        x = self.log_softmax(x, 0)
        return x

    mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
    self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
    # The expected text shows the Python float attr rendered as f32.
    self.assertEqual(
        mlir.count(
            'composite_attributes = {dim = 0 : i64, source = "torch.nn", version = 1.000000e+00 : f32}'
        ),
        1,
    )

  def test_sdpa_composite(self):
    """A multi-input SDPA region (q, k, v, mask) lowers to one composite."""

    class SDPAModel(torch.nn.Module):

      def scaled_dot_product_attention(
          self,
          q: torch.Tensor,
          k: torch.Tensor,
          v: torch.Tensor,
          head_size: int,
          mask: torch.Tensor,
      ):
        builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention")
        q, k, v, mask = builder.mark_inputs(q, k, v, mask)

        scale = 1.0 / math.sqrt(head_size)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # A mask is always passed in this test, so is_causal evaluates False.
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=0.0,
            is_causal=mask is None,
            scale=scale,
        )
        result = y.transpose(1, 2)
        result = builder.mark_outputs(result)
        return result

      def forward(self, q, k, v, mask):
        x = self.scaled_dot_product_attention(
            q,
            k,
            v,
            8,
            mask,
        )
        return x

    query = torch.rand(1, 1, 32, 4)
    key = torch.rand(1, 500, 1, 4)
    value = torch.rand(1, 500, 1, 4)
    mask = torch.rand(1, 1, 1, 500)

    mlir = _export_stablehlo_mlir(
        SDPAModel().eval(),
        (query, key, value, mask),
    )
    self.assertEqual(
        mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1
    )

  def test_sdpa_composite_with_attr(self):
    """A bool attr distinguishes two SDPA composites over the same inputs."""

    class SDPAModel(torch.nn.Module):

      def scaled_dot_product_attention(
          self,
          q: torch.Tensor,
          k: torch.Tensor,
          v: torch.Tensor,
          head_size: int,
          include_captanh: bool,
      ):
        builder = StableHLOCompositeBuilder(
            name="test.scaled_dot_product_attention",
            attr={"include_captanh": include_captanh},
        )
        q, k, v = builder.mark_inputs(q, k, v)

        scale = 1.0 / math.sqrt(head_size)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,
            scale=scale,
        )
        result = y.transpose(1, 2)
        result = builder.mark_outputs(result)
        return result

      def forward(self, q, k, v):
        # One composite with include_captanh=True and one with False.
        x = self.scaled_dot_product_attention(q, k, v, 8, True)
        y = self.scaled_dot_product_attention(q, k, v, 8, False)
        return x + y

    query = torch.rand(1, 1, 32, 4)
    key = torch.rand(1, 500, 1, 4)
    value = torch.rand(1, 500, 1, 4)
    mlir = _export_stablehlo_mlir(
        SDPAModel().eval(),
        (query, key, value),
    )
    self.assertEqual(
        mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
    )
    self.assertEqual(mlir.count("composite_attributes = {include_captanh = true}"), 1)
    self.assertEqual(mlir.count("composite_attributes = {include_captanh = false}"), 1)

  def test_build_composite_with_multiple_inputs_outputs(self):
    """Regions with three inputs and three outputs lower to composites."""

    class SampleModel(torch.nn.Module):

      def mimo_sample(self, a, b, c):
        builder = StableHLOCompositeBuilder(name="test.mimo_sample")

        a, b, c = builder.mark_inputs(a, b, c)
        x = a + b + c
        y = (a - b) * x
        z = (c + 1.0) * a
        x, y, z = builder.mark_outputs(x, y, z)

        # Combining the outputs happens outside the marked region.
        result = x + y * z
        return result

      def forward(self, a, b, c):
        # Three invocations, chained through their results.
        x = self.mimo_sample(a, b, c)
        x = self.mimo_sample(a, b, x)
        x = self.mimo_sample(x, x, c)
        return x

    mlir = _export_stablehlo_mlir(
        SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2))
    )
    self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
if __name__ == "__main__":
  # Run this module's unittest cases when it is executed as a script.
  unittest.main()
|