ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release.

This version of ai-edge-torch-nightly might be problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py
@@ -0,0 +1,215 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import dataclasses
+ import operator
+
+ import torch
+ from torch.fx import Node
+
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry  # NOQA
+
+ aten = torch.ops.aten
+
+ __all__ = [
+     "is_4d",
+     "can_be_nhwc",
+     "must_be_nhwc",
+     "get_layout_sensitive_inputs",
+     "get_no_rewriter_nhwc_ops",
+ ]
+
+
+ class LayoutSensitiveInputsGettersRegistry(OpFuncRegistry):
+
+     def __missing__(self, op):
+
+         def _default_getter(node: Node):
+             """Default layout sensitive inputs are all input nodes."""
+             return node.all_input_nodes
+
+         return _default_getter
+
+
+ @dataclasses.dataclass
+ class NHWCable:
+     can_be: bool
+     must_be: bool
+
+     def __bool__(self):
+         raise RuntimeError(
+             "Boolean value on NHWCable is disabled. Please call .can_be or .must_be"
+         )
+
+
+ class NHWCableNodeCheckersRegistry(OpFuncRegistry):
+
+     def __init__(self):
+         self.no_rewriter_nhwc_ops = set()
+
+     def __missing__(self, op):
+
+         def _default_checker(node: Node):
+             """Default checker for most layout insensitive ops.
+
+             The node should be marked and rewritten to NHWC if:
+             1. The node output is a single 4-D tensor.
+             2. All layout sensitive input nodes (by default, all inputs) of
+                this node are marked as NHWC.
+             3. All layout sensitive input nodes return 4-D tensors.
+             4. A rewrite rule is registered for this node (explicit
+                registration is required even for no-op rewrites).
+             """
+             nonlocal self
+             layout_sensitive_inputs = get_layout_sensitive_inputs(node)
+
+             can_be_nhwc = is_4d(node) and all_layout_sensitive_inputs_are_4d(node)
+             has_rewriter = layout_rewrite.has_nhwc_rewriter(node)
+
+             if can_be_nhwc and not has_rewriter:
+                 self.no_rewriter_nhwc_ops.add(node.target)
+
+             return NHWCable(can_be_nhwc and has_rewriter, must_be=False)
+
+         return _default_checker
+
+
+ nhwcable_node_checkers = NHWCableNodeCheckersRegistry()
+ layout_sensitive_inputs_getters = LayoutSensitiveInputsGettersRegistry()
+
+
+ def can_be_nhwc(node: Node):
+     return nhwcable_node_checkers[node.target](node).can_be
+
+
+ def must_be_nhwc(node: Node):
+     return nhwcable_node_checkers[node.target](node).must_be
+
+
+ def get_layout_sensitive_inputs(node: Node):
+     return layout_sensitive_inputs_getters[node.target](node)
+
+
+ def get_no_rewriter_nhwc_ops():
+     """Debug only: get ops that could be NHWC but have no rewriter registered."""
+     return nhwcable_node_checkers.no_rewriter_nhwc_ops
+
+
+ def is_4d(node: Node):
+     val = node.meta.get("val")
+     if val is None:
+         return False
+     if not hasattr(val, "shape"):
+         return False
+
+     return len(val.shape) == 4
+
+
+ def all_layout_sensitive_inputs_are_4d(node: Node):
+     return all(is_4d(m) for m in get_layout_sensitive_inputs(node))
+
+
+ # ==== Quantize ops (use default NHWC checker)
+
+
+ @layout_sensitive_inputs_getters.register(
+     torch.ops.quantized_decomposed.dequantize_per_tensor
+ )
+ @layout_sensitive_inputs_getters.register(
+     torch.ops.quantized_decomposed.quantize_per_tensor
+ )
+ @layout_sensitive_inputs_getters.register(
+     torch.ops.quantized_decomposed.dequantize_per_channel
+ )
+ @layout_sensitive_inputs_getters.register(
+     torch.ops.quantized_decomposed.quantize_per_channel
+ )
+ def _qdq_layout_sensitive_inputs_getter(node: Node):
+     return [node.args[0]]
+
+
+ # ==== Ops that must be NHWC if possible
+
+
+ @layout_sensitive_inputs_getters.register(aten.convolution)
+ @layout_sensitive_inputs_getters.register(aten._native_batch_norm_legit_no_training)
+ @layout_sensitive_inputs_getters.register(aten.native_group_norm)
+ def _first_arg_getter(node):
+     return [node.args[0]]
+
+
+ # Note: layout sensitive inputs default to all inputs when not specified.
+ @nhwcable_node_checkers.register(aten.max_pool2d)
+ @nhwcable_node_checkers.register(aten.max_pool2d_with_indices)
+ @nhwcable_node_checkers.register(aten.amax)
+ @nhwcable_node_checkers.register(aten.avg_pool2d)
+ @nhwcable_node_checkers.register(aten._prelu_kernel)
+ @nhwcable_node_checkers.register(aten.upsample_bilinear2d)
+ @nhwcable_node_checkers.register(aten.upsample_nearest2d)
+ @nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
+ @nhwcable_node_checkers.register(aten.convolution)
+ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
+     can_be = all_layout_sensitive_inputs_are_4d(node)
+     return NHWCable(can_be, must_be=can_be)
+
+
+ @nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
+ @nhwcable_node_checkers.register(aten.native_group_norm)
+ def _aten_norm_checker(node):
+     val = node.meta.get("val")
+     if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
+         return NHWCable(can_be=False, must_be=False)
+     return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
+
+
+ # ==== Ops that must be NCHW
+
+
+ @nhwcable_node_checkers.register(torch.ops.xla.mark_tensor)
+ @nhwcable_node_checkers.register(utils.tensor_to_nchw)
+ @nhwcable_node_checkers.register(utils.tensor_to_nhwc)
+ @nhwcable_node_checkers.register("output")
+ @nhwcable_node_checkers.register(aten.view)
+ @nhwcable_node_checkers.register(aten.unsqueeze_copy)
+ @nhwcable_node_checkers.register(aten.expand)
+ @nhwcable_node_checkers.register(aten.permute)
+ @nhwcable_node_checkers.register(aten.as_strided)
+ def _not_nhwc(node: Node):
+     return NHWCable(can_be=False, must_be=False)
+
+
+ # ==== Others
+
+
+ @layout_sensitive_inputs_getters.register(aten.index)
+ @layout_sensitive_inputs_getters.register(aten._unsafe_index)
+ def _aten_index_layout_sensitive_inputs_getter(node):
+     return [node.args[0]]
+
+
+ @nhwcable_node_checkers.register(aten.index)
+ @nhwcable_node_checkers.register(aten._unsafe_index)
+ def _aten_index_checker(node):
+     layout_sensitive_inputs = get_layout_sensitive_inputs(node)
+     can_be = is_4d(node) and all_layout_sensitive_inputs_are_4d(node)
+     return NHWCable(can_be, must_be=False)
+
+
+ @nhwcable_node_checkers.register(operator.getitem)
+ def _getitem_checker(node):
+     src = node.args[0]
+     return nhwcable_node_checkers[src.target](src)
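The two registries in this file share one pattern: register() installs a per-op function via decorator stacking, and dict's __missing__ hook supplies a fallback for unregistered ops. A minimal self-contained sketch of that pattern, using a hypothetical MiniRegistry rather than the package's OpFuncRegistry (defined in op_func_registry.py, not shown in this hunk):

class MiniRegistry(dict):
    """Toy stand-in for OpFuncRegistry: a dict mapping op -> function."""

    def register(self, op):
        # Decorator: store fn under the given op key and return fn unchanged,
        # so the same function can be stacked under several ops.
        def decorator(fn):
            self[op] = fn
            return fn

        return decorator

    def __missing__(self, op):
        # Fallback for ops with no registered function, mirroring the
        # __missing__ hooks in the registries above.
        return lambda node: f"default handling for {op}"


checkers = MiniRegistry()


@checkers.register("aten.view")
def _view_checker(node):
    return "never NHWC"


print(checkers["aten.view"](None))  # registered checker
print(checkers["aten.add"](None))   # falls back to __missing__

Registering one function under many ops, as the stacked decorators above do, works because each register() call simply adds another key pointing at the same function object.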
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py
@@ -0,0 +1,48 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+
+ # Tag added to a node's meta to indicate that it is part of the NHWC
+ # partition.
+ IS_NHWC_NODE = "OPTIMIZE_LAYOUT_TRANSPOSES_PASS__IS_NHWC_NODE"
+
+
+ # Tag added to a node's meta to indicate that it is derived completely
+ # from constant and/or weight tensor(s).
+ IS_CONST_NODE = "OPTIMIZE_LAYOUT_TRANSPOSES_PASS__IS_CONST_NODE"
+
+
+ def mark_as_nhwc_node(node: torch.fx.Node) -> None:
+     node.meta[IS_NHWC_NODE] = True
+
+
+ def mark_as_nchw_node(node: torch.fx.Node) -> None:
+     node.meta[IS_NHWC_NODE] = False
+
+
+ def is_nhwc_node(node: torch.fx.Node) -> bool:
+     return node.meta.get(IS_NHWC_NODE, False)
+
+
+ def is_nchw_node(node: torch.fx.Node) -> bool:
+     return not is_nhwc_node(node)
+
+
+ def mark_as_const_node(node: torch.fx.Node) -> None:
+     node.meta[IS_CONST_NODE] = True
+
+
+ def is_const_node(node: torch.fx.Node) -> bool:
+     return node.meta.get(IS_CONST_NODE, False)
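These helpers are plain reads and writes of node.meta flags, so the round trip can be checked on any traced graph. A minimal sketch, assuming the package is installed under the import path used above:

import torch
import torch.fx

from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark


def f(x):
    return torch.relu(x)


gm = torch.fx.symbolic_trace(f)
relu = next(n for n in gm.graph.nodes if n.op == "call_function")

assert layout_mark.is_nchw_node(relu)  # unmarked nodes default to NCHW
layout_mark.mark_as_nhwc_node(relu)
assert layout_mark.is_nhwc_node(relu)
layout_mark.mark_as_nchw_node(relu)
assert layout_mark.is_nchw_node(relu)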
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from . import greedy
+ from . import min_cut
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py
@@ -0,0 +1,59 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+
+
+ def partition(graph_module: torch.fx.GraphModule):
+     """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
+     nodes in the NHWC partitions.
+
+     Implements an O(|V|) greedy partitioning algorithm.
+     See go/pytorch-layout-transpose-optimization for more details.
+     """
+     graph = graph_module.graph
+
+     for node in list(graph.nodes):
+         if len(node.all_input_nodes) == 0:
+             # This node has no inputs, so nothing needs to change.
+             continue
+
+         if layout_check.must_be_nhwc(node):
+             # must_be_nhwc is true for this node: mark it as NHWC.
+             layout_mark.mark_as_nhwc_node(node)
+         elif layout_check.can_be_nhwc(node):
+             # Mark this node as NHWC if all of the following hold:
+             # - can_be_nhwc is true for this node
+             # - at least one layout sensitive input is marked as NHWC
+             # - every layout sensitive input is marked as NHWC or is a 4D tensor
+             layout_sensitive_inputs = layout_check.get_layout_sensitive_inputs(node)
+
+             should_be_nhwc = any(map(layout_mark.is_nhwc_node, layout_sensitive_inputs))
+             for input_node in layout_sensitive_inputs:
+                 if not layout_mark.is_nhwc_node(input_node) and not layout_check.is_4d(
+                     input_node
+                 ):
+                     should_be_nhwc = False
+
+             if should_be_nhwc:
+                 layout_mark.mark_as_nhwc_node(node)
+
+     graph_module.recompile()
+     return graph_module
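To see the greedy rule in isolation, here is a hypothetical standalone rendition on a conv -> relu -> view chain, with plain dicts standing in for the layout_check queries (illustration only, not the package's API):

# Per-node answers that layout_check would normally provide.
must_be_nhwc = {"conv": True, "relu": False, "view": False}
can_be_nhwc = {"conv": True, "relu": True, "view": False}
inputs = {"conv": ["x"], "relu": ["conv"], "view": ["relu"]}
is_4d = {"x": True, "conv": True, "relu": True, "view": False}

nhwc = set()
for node in ["conv", "relu", "view"]:  # nodes visited in topological order
    if must_be_nhwc[node]:
        nhwc.add(node)
    elif can_be_nhwc[node]:
        ins = inputs[node]
        # NHWC propagates when some input is already NHWC and every
        # remaining input is at least a 4D tensor.
        if any(i in nhwc for i in ins) and all(i in nhwc or is_4d[i] for i in ins):
            nhwc.add(node)

print(sorted(nhwc))  # ['conv', 'relu']: view stays NCHW, so one transpose after relu

Because the pass visits nodes in graph order, a single must-be-NHWC node (here, conv) seeds an NHWC island that grows through every downstream node that can follow it, and stops at the first node that cannot.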
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py
@@ -0,0 +1,196 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import collections
+ import dataclasses
+ import itertools
+
+ import numpy as np
+ import scipy
+ import torch
+
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
+
+
+ class MinCutSolver:
+     # A number that is large enough but still fits into int32 through all
+     # computations in the maximum flow.
+     INF_COST = 1 << 28
+
+     def __init__(self):
+         self._edges_map = collections.defaultdict(dict)
+         self._obj_to_node = {}
+         self._node_to_obj = {}
+         self._nodes_cnt = 0
+
+         self.source = self._next_nid()
+         self.sink = self._next_nid()
+
+     def _next_nid(self):
+         nid = self._nodes_cnt
+         self._nodes_cnt += 1
+         return nid
+
+     @property
+     def nodes(self):
+         return list(range(self._nodes_cnt))
+
+     @property
+     def edges_map(self):
+         return self._edges_map
+
+     @property
+     def edges(self):
+         return [
+             [n, m, cost]
+             for n, next_nodes in self._edges_map.items()
+             for m, cost in next_nodes.items()
+         ]
+
+     @property
+     def graph(self):
+         edges = np.array(self.edges)
+         return scipy.sparse.csr_matrix(
+             (np.minimum(edges[:, 2], MinCutSolver.INF_COST), (edges[:, 0], edges[:, 1])),
+             shape=(self._nodes_cnt, self._nodes_cnt),
+             dtype=np.int32,
+         )
+
+     def get_nid(self, obj=None):
+         if obj is None:
+             return self._next_nid()
+
+         nid = self._obj_to_node.get(obj)
+         if nid is None:
+             nid = self._next_nid()
+
+             self._obj_to_node[obj] = nid
+             self._node_to_obj[nid] = obj
+         return nid
+
+     def get_obj(self, nid: int):
+         return self._node_to_obj.get(nid, None)
+
+     def add_edge(self, a_id: int, b_id: int, cost: int):
+         assert isinstance(cost, int)
+         self._edges_map[a_id][b_id] = cost
+
+     def solve(self):
+         flow = scipy.sparse.csgraph.maximum_flow(
+             self.graph, self.source, self.sink, method="dinic"
+         ).flow
+
+         # Max-flow min-cut theorem: find the min-cut in the residual network.
+         ds = scipy.cluster.hierarchy.DisjointSet(self.nodes)
+         for n, m, cost in self.edges:
+             if abs(flow[n, m]) < cost:
+                 ds.merge(n, m)
+
+         residual_reachable_nodes = ds.subset(self.source)
+
+         cuts = set()
+         for n, m, cost in self.edges:
+             if n in residual_reachable_nodes and m not in residual_reachable_nodes:
+                 cuts.add((n, m))
+
+         return cuts
+
+
+ @dataclasses.dataclass(frozen=True)
+ class MultiUsersDummyNode:
+     src: torch.fx.Node
+
+
+ def partition(graph_module: torch.fx.GraphModule):
+     """Partition the graph module into NHWC and non-NHWC subgraphs, and mark
+     nodes in the NHWC partitions.
+
+     Implements an O(|V|^2 |E|) min-cut (optimal) partitioning algorithm.
+     See go/pytorch-layout-transpose-optimization for more details.
+     """
+     graph = graph_module.graph
+
+     mc_solver = MinCutSolver()
+     for fx_node in graph.nodes:
+         if layout_mark.is_const_node(fx_node):
+             continue
+
+         nid = mc_solver.get_nid(fx_node)
+         if fx_node.op in ("placeholder", "output"):
+             # Graph inputs and outputs are never NHWC nodes; connect them
+             # to the source S directly, with inf cost to cut.
+             mc_solver.add_edge(mc_solver.source, nid, cost=MinCutSolver.INF_COST)
+         elif not layout_check.can_be_nhwc(fx_node):
+             # All non-NHWCable nodes are connected to the source S directly,
+             # with inf cost to cut.
+             mc_solver.add_edge(mc_solver.source, nid, cost=MinCutSolver.INF_COST)
+         elif layout_check.must_be_nhwc(fx_node):
+             # All must-be-NHWC nodes are connected to the sink T directly,
+             # with inf cost to cut.
+             mc_solver.add_edge(nid, mc_solver.sink, cost=MinCutSolver.INF_COST)
+
+         cut_cost = 10  # 10 is one unit of cut cost
+         if fx_node.target in (torch.ops.aten.mean.default, torch.ops.aten.mean.dim):
+             # The TFLite converter can fuse the lowering of (mean-transpose) but
+             # not (transpose-mean) when mean applies to the feature dimensions.
+             # Therefore, decrease the cut cost of aten.mean's outgoing edges to
+             # favor a cut (transpose) after the node rather than before it when
+             # the numbers of transposes are otherwise equal.
+             # TODO: Remove this rule when the converter can fuse transpose-mean.
+             cut_cost = 9
+
+         if len(fx_node.users) > 1:
+             # If a node's (A1) output is used by multiple nodes (B1, B2, B3, ...),
+             # the cost to split A1 and the Bs into different partitions is just
+             # one transpose. Introduce a dummy node between A1 and the Bs in the
+             # min-cut graph to reflect that disconnecting them doesn't introduce
+             # multiple transposes.
+             dummy_nid = mc_solver.get_nid(MultiUsersDummyNode(fx_node))
+             mc_solver.add_edge(nid, dummy_nid, cost=cut_cost)
+             mc_solver.add_edge(dummy_nid, nid, cost=cut_cost)
+             nid = dummy_nid
+
+         for user in fx_node.users:
+             # All other nodes and edges in the model graph are carried over
+             # as is into the new graph, with cut_cost to cut an edge.
+             user_id = mc_solver.get_nid(user)
+             mc_solver.add_edge(nid, user_id, cost=cut_cost)
+             mc_solver.add_edge(user_id, nid, cost=cut_cost)
+
+     cuts = mc_solver.solve()
+
+     # Find nodes that are connected to the sink after the min-cut and mark
+     # them as NHWC.
+     ds = scipy.cluster.hierarchy.DisjointSet(mc_solver.nodes)
+     for n, m, cost in mc_solver.edges:
+         if (n, m) in cuts or (m, n) in cuts:
+             continue
+         ds.merge(n, m)
+     assert not ds.connected(mc_solver.source, mc_solver.sink)
+
+     for nid in mc_solver.nodes:
+         if ds.connected(nid, mc_solver.source):
+             continue
+
+         obj = mc_solver.get_obj(nid)
+         if obj is None:
+             continue
+         if isinstance(obj, MultiUsersDummyNode):
+             continue
+
+         assert isinstance(obj, torch.fx.Node)
+         layout_mark.mark_as_nhwc_node(obj)
+
+     graph_module.recompile()
+     return graph_module
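The cut extraction in solve() can be seen end to end on a toy graph: run maximum_flow, merge the endpoints of every unsaturated edge into a residual component, and everything outside the source's component lands on the NHWC side. A minimal sketch using only scipy, with hypothetical node ids (node 2 pinned to the source, node 3 to the sink):

import numpy as np
import scipy.sparse
from scipy.cluster.hierarchy import DisjointSet
from scipy.sparse.csgraph import maximum_flow

INF = 1 << 28  # same "large but int32-safe" trick as MinCutSolver.INF_COST
edges = [
    (0, 2, INF),  # node 2 cannot be NHWC: pinned to source 0
    (2, 3, 10),   # cuttable bidirectional edge between nodes 2 and 3
    (3, 2, 10),
    (3, 1, INF),  # node 3 must be NHWC: pinned to sink 1
]
rows, cols, costs = zip(*edges)
graph = scipy.sparse.csr_matrix(
    (np.array(costs, dtype=np.int32), (rows, cols)), shape=(4, 4)
)
flow = maximum_flow(graph, 0, 1, method="dinic").flow

# Merge endpoints of unsaturated edges; the saturated edges form the min-cut.
# abs() treats a pair of opposite edges as cut when either direction is
# saturated, exactly as in MinCutSolver.solve().
ds = DisjointSet(range(4))
for u, v, c in edges:
    if abs(flow[u, v]) < c:
        ds.merge(u, v)

nhwc_side = [v for v in range(4) if v != 1 and not ds.connected(v, 0)]
print(nhwc_side)  # [3]: the cut crosses (2, 3), i.e. one transpose there

With one node pinned to each side, the cheapest way to separate source from sink is the single unit-cost edge between nodes 2 and 3, so exactly one transpose is inserted on that edge and node 3 is marked NHWC.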