ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch.fx.passes.infra.pass_base import PassBase
|
|
18
|
+
from torch.fx.passes.infra.pass_base import PassResult
|
|
19
|
+
import torch.utils._pytree as pytree
|
|
20
|
+
import torch_xla.experimental.xla_mlir_debuginfo # Import required to register torch.ops.xla.write_mlir_debuginfo
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _get_mlir_debuginfo(node: torch.fx.Node):
  """Build the debuginfo string attached to the outputs of `node`.

  The string is derived from the node's `nn_module_stack` metadata: one entry
  per stack frame, joined with "/" and terminated with ";".
  """

  def class_fullname(cls):
    # Builtin classes are reported by bare qualified name, without a module
    # prefix.
    owner = cls.__module__
    if owner != "builtins":
      return owner + "." + cls.__qualname__
    return cls.__qualname__

  def get_hierarchy(node: torch.fx.Node):
    parts = []
    for name, layer in node.meta.get("nn_module_stack", {}).values():
      # The stack entry's layer may be either a string or a class object.
      layer_name = layer if isinstance(layer, str) else class_fullname(layer)
      # Only the last dotted component of the instance name is kept.
      suffix = ("_" + name.split(".")[-1]) if name else ""
      parts.append(layer_name + suffix)
    return "/".join(parts) + ";"

  # TODO(yijieyang): Encode aten op and attrs.
  return get_hierarchy(node)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.Node):
  """Replace a call_function node's target with a debuginfo-writing wrapper.

  Nodes whose `op` does not start with "call_function" are left untouched.
  Otherwise `node.target` is swapped for a wrapper that calls the original
  target and passes every `torch.Tensor` in its (possibly nested) outputs
  through `torch.ops.xla.write_mlir_debuginfo` with the node's debuginfo
  string.

  Args:
    node: The FX graph node to wrap. It is modified in place.
  """
  if not node.op.startswith("call_function"):
    return

  # Capture both values now: `node.target` is overwritten below, and the
  # debuginfo string is fixed at wrapping time.
  target = node.target
  debuginfo = _get_mlir_debuginfo(node)

  def debuginfo_writer(*args, **kwargs):
    # `target` and `debuginfo` are only read, so plain closure capture is
    # sufficient (no `nonlocal` needed).
    outputs = target(*args, **kwargs)
    return pytree.tree_map_only(
        torch.Tensor,
        lambda x: torch.ops.xla.write_mlir_debuginfo(x, debuginfo),
        outputs,
    )

  node.target = debuginfo_writer
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class InjectMlirDebuginfoPass(PassBase):
  """FX pass that wraps every node of a graph module with a debuginfo writer."""

  def call(self, graph_module: torch.fx.GraphModule):
    """Wrap the nodes of `graph_module` in place and report it as modified."""
    for fx_node in graph_module.graph.nodes:
      # Non call_function nodes are skipped inside the helper.
      _wrap_call_function_node_with_debuginfo_writer(fx_node)

    graph_module.graph.lint()
    graph_module.recompile()
    # The pass always reports the module as modified.
    return PassResult(graph_module, True)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import dataclasses
|
|
16
|
+
import operator
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch.fx import Node
|
|
20
|
+
|
|
21
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
22
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
|
|
23
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
|
|
24
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry # NOQA
|
|
25
|
+
|
|
26
|
+
aten = torch.ops.aten
|
|
27
|
+
|
|
28
|
+
__all__ = [
|
|
29
|
+
"is_4d",
|
|
30
|
+
"can_be_nhwc",
|
|
31
|
+
"must_be_nhwc",
|
|
32
|
+
"get_layout_sensitive_inputs",
|
|
33
|
+
"get_no_rewriter_nhwc_ops",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class LayoutSensitiveInputsGettersRegistry(OpFuncRegistry):
  """Registry of per-op functions returning a node's layout sensitive inputs."""

  def __missing__(self, op):
    # Fallback used for ops without a registered getter.
    # NOTE(review): relies on OpFuncRegistry honouring `__missing__`
    # (dict-like lookup) - confirm against op_func_registry.py.

    def _all_inputs_getter(node: Node):
      """Default layout sensitive inputs are all input nodes."""
      return node.all_input_nodes

    return _all_inputs_getter
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclasses.dataclass
class NHWCable:
  """Result of an NHWC check: the `can_be` and `must_be` flags for a node.

  Truth-testing an instance is deliberately disabled so callers have to pick
  one of the two flags explicitly.
  """

  can_be: bool
  must_be: bool

  def __bool__(self):
    message = (
        "Boolean value on NHWCable is disabled. Please call .can_be or .must_be"
    )
    raise RuntimeError(message)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class NHWCableNodeCheckersRegistry(OpFuncRegistry):
  """Registry of per-op checkers that return an `NHWCable` for a node.

  Ops without a registered checker fall back to a default checker (see
  `__missing__`). The registry also records, for debugging, the targets that
  passed the 4-D checks but have no NHWC rewriter.
  """

  def __init__(self):
    # Initialize the base registry before adding our own state.
    super().__init__()
    self.no_rewriter_nhwc_ops = set()

  def __missing__(self, op):

    def _default_checker(node: Node):
      """Default checker for most of the layout insensitive ops.

      The node is reported as `can_be` NHWC if:
        1. The node output is a single 4-D tensor.
        2. All layout sensitive input nodes (default: all inputs) return
           4-D tensors.
        3. There exists a rewrite rule for this node (explicit registry
           required for noop.)
      `must_be` is always False. This checker does not look at whether the
      inputs are already marked as NHWC.
      """
      # Look the inputs up once and reuse them for the 4-D check.
      layout_sensitive_inputs = get_layout_sensitive_inputs(node)

      can_be_nhwc = is_4d(node) and all(
          is_4d(m) for m in layout_sensitive_inputs
      )
      has_rewriter = layout_rewrite.has_nhwc_rewriter(node)

      if can_be_nhwc and not has_rewriter:
        # Debug bookkeeping: this op could have been NHWC if a rewriter
        # existed.
        self.no_rewriter_nhwc_ops.add(node.target)

      return NHWCable(can_be_nhwc and has_rewriter, must_be=False)

    return _default_checker
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
nhwcable_node_checkers = NHWCableNodeCheckersRegistry()
|
|
92
|
+
layout_sensitive_inputs_getters = LayoutSensitiveInputsGettersRegistry()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def can_be_nhwc(node: Node):
  """Return the `can_be` flag from the checker registered for `node.target`."""
  checker = nhwcable_node_checkers[node.target]
  return checker(node).can_be
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def must_be_nhwc(node: Node):
  """Return the `must_be` flag from the checker registered for `node.target`."""
  checker = nhwcable_node_checkers[node.target]
  return checker(node).must_be
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_layout_sensitive_inputs(node: Node):
  """Return `node`'s layout sensitive inputs via the getter for its target."""
  getter = layout_sensitive_inputs_getters[node.target]
  return getter(node)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_no_rewriter_nhwc_ops():
  """Debug only: return the ops recorded as having no NHWC rewriter.

  The returned set is the one stored on the module-level checker registry,
  not a copy.
  """
  recorded_ops = nhwcable_node_checkers.no_rewriter_nhwc_ops
  return recorded_ops
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def is_4d(node: Node):
  """Return True if `node.meta["val"]` exists and has a rank-4 `shape`.

  A missing/None `val`, or a `val` without a `shape` attribute, yields False.
  """
  val = node.meta.get("val")
  if val is None or not hasattr(val, "shape"):
    return False

  rank = len(val.shape)
  return rank == 4
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def all_layout_sensitive_inputs_are_4d(node: Node):
  """Return True if every layout sensitive input of `node` is 4-D.

  Vacuously True when the node has no layout sensitive inputs.
  """
  for input_node in get_layout_sensitive_inputs(node):
    if not is_4d(input_node):
      return False
  return True
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# ==== Quantize ops (use default NHWC checker)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.dequantize_per_tensor
)
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.quantize_per_tensor
)
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.dequantize_per_channel
)
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.quantize_per_channel
)
def _qdq_layout_sensitive_inputs_getter(node: Node):
  """Quantize/dequantize ops: only the first positional argument counts."""
  first_arg = node.args[0]
  return [first_arg]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# ==== Ops must be NHWC if possible
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@layout_sensitive_inputs_getters.register(aten.convolution)
@layout_sensitive_inputs_getters.register(aten._native_batch_norm_legit_no_training)
@layout_sensitive_inputs_getters.register(aten.native_group_norm)
def _first_arg_getter(node):
  """Convolution and norm ops: only the first positional argument counts."""
  first_arg = node.args[0]
  return [first_arg]
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# Note: default layout sensitive inputs are all inputs when not specified.
|
|
156
|
+
@nhwcable_node_checkers.register(aten.max_pool2d)
@nhwcable_node_checkers.register(aten.max_pool2d_with_indices)
@nhwcable_node_checkers.register(aten.amax)
@nhwcable_node_checkers.register(aten.avg_pool2d)
@nhwcable_node_checkers.register(aten._prelu_kernel)
@nhwcable_node_checkers.register(aten.upsample_bilinear2d)
@nhwcable_node_checkers.register(aten.upsample_nearest2d)
@nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
@nhwcable_node_checkers.register(aten.convolution)
def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
  """Report both `can_be` and `must_be` as "all layout sensitive inputs are 4-D"."""
  inputs_are_4d = all_layout_sensitive_inputs_are_4d(node)
  # These ops must be NHWC whenever they can be.
  return NHWCable(can_be=inputs_are_4d, must_be=inputs_are_4d)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
@nhwcable_node_checkers.register(aten.native_group_norm)
def _aten_norm_checker(node):
  """Norm ops can be NHWC when the first element of their output is 4-D.

  `node.meta["val"]` is expected to be a non-empty list/tuple whose first
  element has a `shape`; anything else is reported as not NHWC-able.
  `must_be` is always False.
  """
  val = node.meta.get("val")
  if isinstance(val, (list, tuple)) and val and hasattr(val[0], "shape"):
    first_output_rank = len(val[0].shape)
    return NHWCable(can_be=first_output_rank == 4, must_be=False)
  return NHWCable(can_be=False, must_be=False)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# ==== Ops must be NCHW
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@nhwcable_node_checkers.register(torch.ops.xla.mark_tensor)
@nhwcable_node_checkers.register(utils.tensor_to_nchw)
@nhwcable_node_checkers.register(utils.tensor_to_nhwc)
@nhwcable_node_checkers.register("output")
@nhwcable_node_checkers.register(aten.view)
@nhwcable_node_checkers.register(aten.unsqueeze_copy)
@nhwcable_node_checkers.register(aten.expand)
@nhwcable_node_checkers.register(aten.permute)
@nhwcable_node_checkers.register(aten.as_strided)
def _not_nhwc(node: Node):
  """Checker for targets that are never NHWC: both flags are False."""
  never = False
  return NHWCable(can_be=never, must_be=never)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# ==== Others
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@layout_sensitive_inputs_getters.register(aten.index)
@layout_sensitive_inputs_getters.register(aten._unsafe_index)
def _aten_index_layout_sensitive_inputs_getter(node):
  """Index ops: only the first positional argument counts."""
  first_arg = node.args[0]
  return [first_arg]
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@nhwcable_node_checkers.register(aten.index)
@nhwcable_node_checkers.register(aten._unsafe_index)
def _aten_index_checker(node):
  """Index ops can be NHWC when the node and its layout sensitive inputs are 4-D.

  Unlike the default checker, no NHWC rewriter lookup is performed here.
  `must_be` is always False.
  """
  # Look the inputs up once and reuse them, instead of computing an unused
  # local and then asking the registry again.
  layout_sensitive_inputs = get_layout_sensitive_inputs(node)
  can_be = is_4d(node) and all(is_4d(m) for m in layout_sensitive_inputs)
  return NHWCable(can_be, must_be=False)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
@nhwcable_node_checkers.register(operator.getitem)
def _getitem_checker(node):
  """A getitem node gets the `NHWCable` of the node it indexes into."""
  producer = node.args[0]
  producer_checker = nhwcable_node_checkers[producer.target]
  return producer_checker(producer)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
# Tag which is added to a node's meta to indicate that it is part of the NHWC
|
|
18
|
+
# partition.
|
|
19
|
+
IS_NHWC_NODE = "OPTIMIZE_LAYOUT_TRANSPOSES_PASS__IS_NHWC_NODE"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Tag which is added to a node's meta to indicate that it is derived completely
|
|
23
|
+
# from constant and/or weight tensor(s).
|
|
24
|
+
IS_CONST_NODE = "OPTIMIZE_LAYOUT_TRANSPOSES_PASS__IS_CONST_NODE"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def mark_as_nhwc_node(node: torch.fx.Node) -> None:
  """Set the NHWC tag in `node.meta` to True."""
  meta = node.meta
  meta[IS_NHWC_NODE] = True
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def mark_as_nchw_node(node: torch.fx.Node) -> None:
  """Set the NHWC tag in `node.meta` to False."""
  meta = node.meta
  meta[IS_NHWC_NODE] = False
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def is_nhwc_node(node: torch.fx.Node) -> bool:
  """Return the node's NHWC tag; untagged nodes count as not NHWC."""
  meta = node.meta
  return meta[IS_NHWC_NODE] if IS_NHWC_NODE in meta else False
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def is_nchw_node(node: torch.fx.Node) -> bool:
  """Return True for every node that is not tagged as NHWC."""
  if is_nhwc_node(node):
    return False
  return True
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def mark_as_const_node(node: torch.fx.Node) -> None:
  """Set the const tag in `node.meta` to True."""
  meta = node.meta
  meta[IS_CONST_NODE] = True
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def is_const_node(node: torch.fx.Node) -> bool:
  """Return the node's const tag; untagged nodes count as not const."""
  meta = node.meta
  return meta[IS_CONST_NODE] if IS_CONST_NODE in meta else False
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from . import greedy
|
|
17
|
+
from . import min_cut
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
|
19
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def partition(graph_module: torch.fx.GraphModule):
  """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
  nodes in the NHWC partitions.

  Implements O(|V|) greedy partitioning algorithm: nodes are visited once in
  graph order and tagged based on their own checks and on the tags already
  given to their layout sensitive inputs.
  See go/pytorch-layout-transpose-optimization for more details.

  Returns:
    The same `graph_module`, recompiled, with NHWC nodes tagged in their meta.
  """
  for node in list(graph_module.graph.nodes):
    if not node.all_input_nodes:
      # This node has no inputs so we don't need to change anything
      continue

    if layout_check.must_be_nhwc(node):
      # must_be_nhwc wins unconditionally.
      layout_mark.mark_as_nhwc_node(node)
      continue

    if not layout_check.can_be_nhwc(node):
      continue

    # A can_be_nhwc node is marked as NHWC only when:
    # - at least one layout sensitive input is already marked as NHWC, and
    # - every layout sensitive input is either marked as NHWC or 4-D.
    inputs = layout_check.get_layout_sensitive_inputs(node)
    has_nhwc_input = any(layout_mark.is_nhwc_node(i) for i in inputs)
    all_inputs_compatible = all(
        layout_mark.is_nhwc_node(i) or layout_check.is_4d(i) for i in inputs
    )

    if has_nhwc_input and all_inputs_compatible:
      layout_mark.mark_as_nhwc_node(node)

  graph_module.recompile()
  return graph_module
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import collections
|
|
17
|
+
import dataclasses
|
|
18
|
+
import itertools
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import scipy
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
|
25
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def can_partition(graph_module: torch.fx.GraphModule):
  """Returns true if the input graph_module can be partitioned by min cut solver
  in a reasonable time.

  The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may
  take a long time to complete for a large graph module. The decision is made
  purely from the graph size: |V|^2 * |E| must stay below 2000^3, where |E| is
  counted as the total number of users over all nodes.

  See go/pytorch-layout-transpose-optimization for more details.
  """
  fx_nodes = graph_module.graph.nodes
  n_nodes = len(fx_nodes)

  n_edges = 0
  for fx_node in fx_nodes:
    n_edges += len(fx_node.users)

  # According to experiments on our model set, graphs with |V| < 2000 can
  # generally be partitioned in a reasonable time.
  budget = 2000**3
  return n_nodes * n_nodes * n_edges < budget
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class MinCutSolver:
  """Builds a directed graph with integer edge costs and solves a min cut.

  Node ids are consecutive integers. Two nodes are created up front:
  `source` (id 0) and `sink` (id 1). Further nodes are allocated through
  `get_nid`, optionally keyed by a hashable object.
  """

  # A number that is large enough but can fit into int32 with all computations
  # in the maximum flow.
  INF_COST = 1 << 28

  def __init__(self):
    # a_id -> {b_id: cost}; re-adding an edge overwrites its cost.
    self._edges_map = collections.defaultdict(dict)
    self._obj_to_node = {}
    self._node_to_obj = {}
    self._nodes_cnt = 0

    self.source = self._next_nid()
    self.sink = self._next_nid()

  def _next_nid(self):
    """Allocate and return the next unused node id."""
    nid = self._nodes_cnt
    self._nodes_cnt += 1
    return nid

  @property
  def nodes(self):
    """All allocated node ids, as a list."""
    return list(range(self._nodes_cnt))

  @property
  def edges_map(self):
    """The underlying `{a_id: {b_id: cost}}` mapping (not a copy)."""
    return self._edges_map

  @property
  def edges(self):
    """All edges as a list of `[a_id, b_id, cost]` triples."""
    return [
        [n, m, cost]
        for n, next_nodes in self._edges_map.items()
        for m, cost in next_nodes.items()
    ]

  @property
  def graph(self):
    """The capacity matrix as an int32 CSR matrix, costs capped at INF_COST."""
    # reshape(-1, 3) keeps the array 2-D when no edge has been added yet;
    # a bare np.array([]) is 1-D and the column slicing below would fail.
    edges = np.array(self.edges, dtype=np.int64).reshape(-1, 3)
    return scipy.sparse.csr_matrix(
        (np.minimum(edges[:, 2], MinCutSolver.INF_COST), (edges[:, 0], edges[:, 1])),
        shape=(self._nodes_cnt, self._nodes_cnt),
        dtype=np.int32,
    )

  def get_nid(self, obj=None):
    """Return the node id for `obj`, allocating one on first use.

    With `obj=None` a fresh anonymous node id is allocated on every call.
    """
    if obj is None:
      return self._next_nid()

    nid = self._obj_to_node.get(obj)
    if nid is None:
      nid = self._next_nid()

    self._obj_to_node[obj] = nid
    self._node_to_obj[nid] = obj
    return nid

  def get_obj(self, nid: int):
    """Return the object registered for `nid`, or None if there is none."""
    return self._node_to_obj.get(nid, None)

  def add_edge(self, a_id: int, b_id: int, cost: int):
    """Add (or overwrite) the directed edge `a_id -> b_id` with `cost`."""
    assert isinstance(cost, int)
    self._edges_map[a_id][b_id] = cost

  def solve(self):
    """Compute the cut edges between the source side and the rest.

    Returns:
      A set of `(a_id, b_id)` edges whose tail is in the source's component
      and whose head is not. Empty when the graph has no edges.
    """
    edges = self.edges
    if not edges:
      # Nothing connects to the source, so there is nothing to cut.
      return set()

    flow = scipy.sparse.csgraph.maximum_flow(
        self.graph, self.source, self.sink, method="dinic"
    ).flow

    # Max-flow min-cut theorem: find min-cuts in the residual network.
    # Endpoints of every unsaturated edge are merged into one component.
    ds = scipy.cluster.hierarchy.DisjointSet(self.nodes)
    for n, m, cost in edges:
      if abs(flow[n, m]) < cost:
        ds.merge(n, m)

    residual_reachable_nodes = ds.subset(self.source)

    return {
        (n, m)
        for n, m, _ in edges
        if n in residual_reachable_nodes and m not in residual_reachable_nodes
    }
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@dataclasses.dataclass(frozen=True)
class MultiUsersDummyNode:
    """Placeholder object wrapping an FX node, used as a key for an extra graph node.

    Being a frozen dataclass, instances compare and hash by ``src``, so two
    instances built from the same node are treated as the same key.
    """

    # The FX node this placeholder stands in for.
    src: torch.fx.Node
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def partition(graph_module: torch.fx.GraphModule):
    """Split the graph into NHWC and non-NHWC regions and mark the NHWC nodes.

    Implements O(|V|^2|E|) min-cut (optimal) partitioning algorithm.
    See go/pytorch-layout-transpose-optimization for more details.

    Args:
      graph_module: The FX graph module whose nodes are to be partitioned.

    Returns:
      The same ``graph_module``, recompiled after the NHWC nodes are marked.
    """
    solver = MinCutSolver()
    inf_cost = MinCutSolver.INF_COST
    mean_targets = (torch.ops.aten.mean.default, torch.ops.aten.mean.dim)

    for fx_node in graph_module.graph.nodes:
        if layout_mark.is_const_node(fx_node):
            continue

        nid = solver.get_nid(fx_node)
        if fx_node.op in ("placeholder", "output") or not layout_check.can_be_nhwc(
            fx_node
        ):
            # Graph inputs/outputs and every node that cannot be NHWC are tied
            # to source S with inf cost to cut.
            solver.add_edge(solver.source, nid, cost=inf_cost)
        elif layout_check.must_be_nhwc(fx_node):
            # Nodes that must be NHWC are tied to sink T with inf cost to cut.
            solver.add_edge(nid, solver.sink, cost=inf_cost)

        # 10 is one unit of cut cost.
        # TFLite converter cannot fuse the lowering of (tpos-mean) but
        # (mean-tpos) when it applies on the feature dimensions. aten.mean's
        # out-going edges therefore get a slightly lower cost, so that a cut
        # (transpose) after the node is favored over one before it when the
        # number of transposes is equal.
        # TODO: Remove this rule when converter has fuse rule for tpos-mean.
        cut_cost = 9 if fx_node.target in mean_targets else 10

        if len(fx_node.users) > 1:
            # When a node's (A1) output feeds several nodes (B1, B2, B3, ...),
            # separating A1 from all Bs costs only one transpose. A dummy node
            # is placed between A1 and the Bs in the min-cut graph so that
            # disconnecting them is not counted as multiple transposes.
            fan_out_nid = solver.get_nid(MultiUsersDummyNode(fx_node))
            solver.add_edge(nid, fan_out_nid, cost=cut_cost)
            solver.add_edge(fan_out_nid, nid, cost=cut_cost)
            nid = fan_out_nid

        for user in fx_node.users:
            # The remaining model edges are mirrored as-is, in both directions,
            # each costing one unit of cut cost.
            user_id = solver.get_nid(user)
            solver.add_edge(nid, user_id, cost=cut_cost)
            solver.add_edge(user_id, nid, cost=cut_cost)

    cuts = solver.solve()

    # Regroup the nodes over every edge that was not cut (in either direction);
    # nodes that do not end up grouped with the source are marked as NHWC.
    groups = scipy.cluster.hierarchy.DisjointSet(solver.nodes)
    for n, m, _ in solver.edges:
        if (n, m) not in cuts and (m, n) not in cuts:
            groups.merge(n, m)
    assert not groups.connected(solver.source, solver.sink)

    for nid in solver.nodes:
        if groups.connected(nid, solver.source):
            continue

        obj = solver.get_obj(nid)
        # Ids without a recorded object and dummy fan-out nodes are not FX
        # nodes, so there is nothing to mark for them.
        if obj is None or isinstance(obj, MultiUsersDummyNode):
            continue

        assert isinstance(obj, torch.fx.Node)
        layout_mark.mark_as_nhwc_node(obj)

    graph_module.recompile()
    return graph_module
|