tico 0.1.0.dev250616__py3-none-any.whl → 0.1.0.dev250618__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +6 -2
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +3 -0
- tico/passes/cast_aten_where_arg_type.py +4 -1
- tico/passes/cast_mixed_type_args.py +4 -1
- tico/passes/convert_conv1d_to_conv2d.py +12 -4
- tico/passes/convert_layout_op_to_reshape.py +3 -2
- tico/passes/convert_repeat_to_expand_copy.py +5 -2
- tico/passes/convert_to_relu6.py +4 -3
- tico/passes/decompose_addmm.py +11 -7
- tico/passes/decompose_batch_norm.py +7 -11
- tico/passes/decompose_fake_quantize.py +12 -6
- tico/passes/decompose_fake_quantize_tensor_qparams.py +12 -6
- tico/passes/decompose_group_norm.py +50 -21
- tico/passes/decompose_grouped_conv2d.py +15 -7
- tico/passes/decompose_slice_scatter.py +9 -5
- tico/passes/fuse_leading_unsqueeze_reshape.py +8 -3
- tico/passes/legalize_predefined_layout_operators.py +33 -25
- tico/passes/lower_pow2_to_mul.py +3 -1
- tico/passes/lower_to_resize_nearest_neighbor.py +21 -10
- tico/passes/lower_to_slice.py +21 -11
- tico/passes/remove_redundant_permute.py +5 -3
- tico/passes/remove_redundant_reshape.py +5 -2
- tico/passes/remove_redundant_to_copy.py +4 -0
- tico/passes/restore_linear.py +7 -5
- tico/passes/segment_index_select.py +9 -5
- tico/utils/convert.py +2 -0
- tico/utils/graph.py +48 -2
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/RECORD +35 -34
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/top_level.txt +0 -0
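The recurring change across the per-file diffs below is mechanical: passes that used to call graph.call_function(...) directly now go through a create_node(...) helper from tico/utils/graph.py (+48 −2 in this release) and pass the node that motivated the insertion as origin. The helper itself is not shown in this diff; the sketch below is only a guess at its shape, inferred from the call sites, assuming origin is used to copy debug metadata onto the freshly created node.

# Hypothetical sketch only; the real tico/utils/graph.py is not part of this diff.
from typing import Optional

import torch.fx


def create_node(
    graph: torch.fx.Graph,
    target,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    *,
    origin: Optional[torch.fx.Node] = None,
) -> torch.fx.Node:
    """Insert a call_function node, optionally inheriting metadata from `origin`."""
    node = graph.call_function(target, args, kwargs or {})
    if origin is not None:
        # Assumed behavior: carry debug info from the originating node so that
        # lowered graphs remain traceable back to the source program.
        for key in ("stack_trace", "nn_module_stack"):
            if key in origin.meta:
                node.meta[key] = origin.meta[key]
    return node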
tico/passes/decompose_group_norm.py
CHANGED
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
 
 from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -89,24 +90,40 @@ class DecomposeGroupNorm(PassBase):
     def __init__(self):
         super().__init__()
 
-    def _insert_norm(self, graph, tensor, eps):
+    def _insert_norm(self, graph, tensor, eps, origin):
         """
         Insert (tensor - mean) / sqrt(var + eps)) into the graph
         and return the normalized tensor node.
         """
-        mean = graph.call_function(
-            torch.ops.aten.mean.dim, (tensor, [-1]), {"keepdim": True}
+        mean = create_node(
+            graph,
+            torch.ops.aten.mean.dim,
+            (tensor, [-1]),
+            {"keepdim": True},
+            origin=origin,
         )
-        deviation = graph.call_function(torch.ops.aten.sub.Tensor, (tensor, mean))
-        squared = graph.call_function(torch.ops.aten.pow.Tensor_Scalar, (deviation, 2))
-        var = graph.call_function(
-            torch.ops.aten.mean.dim, (squared, [-1]), {"keepdim": True}
+        deviation = create_node(
+            graph, torch.ops.aten.sub.Tensor, (tensor, mean), origin=origin
        )
-        inverse_std = graph.call_function(
+        squared = create_node(
+            graph, torch.ops.aten.pow.Tensor_Scalar, (deviation, 2), origin=origin
+        )
+        var = create_node(
+            graph,
+            torch.ops.aten.mean.dim,
+            (squared, [-1]),
+            {"keepdim": True},
+            origin=origin,
+        )
+        inverse_std = create_node(
+            graph,
             torch.ops.aten.rsqrt.default,
-            (graph.call_function(torch.ops.aten.add.Tensor, (var, eps)),),
+            (create_node(graph, torch.ops.aten.add.Tensor, (var, eps), origin=origin),),
+            origin=origin,
+        )
+        return create_node(
+            graph, torch.ops.aten.mul.Tensor, (deviation, inverse_std), origin=origin
         )
-        return graph.call_function(torch.ops.aten.mul.Tensor, (deviation, inverse_std))
 
     def call(self, exported_program: ExportedProgram) -> PassResult:
         logger = logging.getLogger(__name__)
@@ -178,17 +195,23 @@ class DecomposeGroupNorm(PassBase):
             # Branch only on whether a reshape is needed; the normalization is shared.
             if norm_size != x_shape[-1]:
                 # Pack groups so that the last dimension equals norm_size.
-                packed = graph.call_function(
-                    torch.ops.aten.reshape.default, (x, pack_shape)
+                packed = create_node(
+                    graph,
+                    torch.ops.aten.reshape.default,
+                    (x, pack_shape),
+                    origin=node,
                 )
-                normed = self._insert_norm(graph, packed, eps)
+                normed = self._insert_norm(graph, packed, eps, origin=node)
                 # Restore the original shape after normalization.
-                layer_norm = graph.call_function(
-                    torch.ops.aten.reshape.default, (normed, x_shape)
+                layer_norm = create_node(
+                    graph,
+                    torch.ops.aten.reshape.default,
+                    (normed, x_shape),
+                    origin=node,
                 )
             else:
                 # The input already has norm_size in the last dimension.
-                layer_norm = self._insert_norm(graph, x, eps)
+                layer_norm = self._insert_norm(graph, x, eps, origin=node)
 
             # weight
             if weight:
@@ -197,13 +220,17 @@ class DecomposeGroupNorm(PassBase):
                 assert weight_shape[0] == C
                 reshape_size = [1] * len(x_shape)
                 reshape_size[1] = C
-                weight = graph.call_function(
+                weight = create_node(
+                    graph,
                     torch.ops.aten.view.default,
                     (weight, reshape_size),
+                    origin=node,
                 )
-                layer_norm = graph.call_function(
+                layer_norm = create_node(
+                    graph,
                     torch.ops.aten.mul.Tensor,
                     (layer_norm, weight),
+                    origin=node,
                 )
 
             # bias
@@ -213,15 +240,17 @@ class DecomposeGroupNorm(PassBase):
                 assert bias_shape[0] == C
                 reshape_size = [1] * len(x_shape)
                 reshape_size[1] = C
-                bias = graph.call_function(
+                bias = create_node(
+                    graph,
                     torch.ops.aten.view.default,
                     (bias, reshape_size),
+                    origin=node,
                 )
-                layer_norm = graph.call_function(
+                layer_norm = create_node(
+                    graph,
                     torch.ops.aten.add.Tensor,
                     (layer_norm, bias),
                 )
-
             # Reset last node's meta for propagating replacing node's meta.
             layer_norm.meta = {}
 
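As a side note, the decomposition DecomposeGroupNorm emits above — pack the groups so the last dimension spans one group, normalize over that dimension, restore the original shape, then apply the per-channel weight and bias — can be checked numerically against torch.nn.functional.group_norm. The snippet is illustrative only, not package code:

import torch

N, C, H, W, groups, eps = 2, 6, 4, 4, 3, 1e-5
x = torch.randn(N, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)

packed = x.reshape(N, groups, -1)                       # last dim == norm_size
mean = packed.mean(dim=-1, keepdim=True)
var = ((packed - mean) ** 2).mean(dim=-1, keepdim=True)
normed = (packed - mean) * torch.rsqrt(var + eps)       # (x - mean) / sqrt(var + eps)
out = normed.reshape(N, C, H, W)
out = out * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

ref = torch.nn.functional.group_norm(x, groups, weight, bias, eps)
assert torch.allclose(out, ref, atol=1e-5)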
tico/passes/decompose_grouped_conv2d.py
CHANGED
@@ -23,7 +23,7 @@ from tico.passes import ops
 from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
-from tico.utils.graph import add_placeholder
+from tico.utils.graph import add_placeholder, create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -159,19 +159,26 @@ class DecomposeGroupedConv2d(PassBase):
 
             conv2d_tensors = []
             for i in range(groups):
-                sliced_input = graph.call_function(
+                sliced_input = create_node(
+                    graph,
                     torch.ops.aten.slice.Tensor,
                     (input_, 1, group_size * i, group_size * (i + 1), 1),
+                    origin=node,
                 )
-                sliced_weight = graph.call_function(
+                sliced_weight = create_node(
+                    graph,
                     torch.ops.aten.slice.Tensor,
                     (weight, 0, out_group_size * i, out_group_size * (i + 1), 1),
+                    origin=node,
                 )
-                sliced_bias = graph.call_function(
+                sliced_bias = create_node(
+                    graph,
                     torch.ops.aten.slice.Tensor,
                     (bias, 0, out_group_size * i, out_group_size * (i + 1), 1),
+                    origin=node,
                 )
-                conv2d_tensor = graph.call_function(
+                conv2d_tensor = create_node(
+                    graph,
                     conv2d_op,
                     (
                         sliced_input,
@@ -182,11 +189,12 @@ class DecomposeGroupedConv2d(PassBase):
                         dilation,
                         1,
                     ),
+                    origin=node,
                 )
                 conv2d_tensors.append(conv2d_tensor)
 
-            concat_output = graph.call_function(
-                torch.ops.aten.cat.default, (conv2d_tensors, 1)
+            concat_output = create_node(
+                graph, torch.ops.aten.cat.default, (conv2d_tensors, 1)
             )
 
             node.replace_all_uses_with(concat_output, propagate_meta=True)
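The DecomposeGroupedConv2d rewrite above relies on the identity that a grouped conv2d equals running a dense conv2d per group on the matching input/weight/bias slices and concatenating the results along the channel axis. An illustrative check (not package code):

import torch
import torch.nn.functional as F

N, C_in, C_out, groups = 1, 8, 12, 4
x = torch.randn(N, C_in, 16, 16)
w = torch.randn(C_out, C_in // groups, 3, 3)
b = torch.randn(C_out)

ref = F.conv2d(x, w, b, padding=1, groups=groups)

in_gs, out_gs = C_in // groups, C_out // groups
parts = [
    F.conv2d(
        x[:, i * in_gs : (i + 1) * in_gs],
        w[i * out_gs : (i + 1) * out_gs],
        b[i * out_gs : (i + 1) * out_gs],
        padding=1,
    )
    for i in range(groups)
]
assert torch.allclose(ref, torch.cat(parts, dim=1), atol=1e-5)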
tico/passes/decompose_slice_scatter.py
CHANGED
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
 from tico.serialize.circle_mapping import extract_shape
 
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import enforce_type, is_target_node
@@ -130,16 +131,19 @@ class DecomposeSliceScatter(PassBase):
             slices = []
 
             if 0 < start:
-                slice_0 = graph.call_function(
+                slice_0 = create_node(
+                    graph,
                     torch.ops.aten.slice_copy.Tensor,
                     args=(input, dim, 0, start, 1),
+                    origin=node,
                 )
                 slices.append(slice_0)
 
             slices.append(src)
 
             if start + end < extract_shape(input)[dim]:
-                slice_1 = graph.call_function(
+                slice_1 = create_node(
+                    graph,
                     torch.ops.aten.slice_copy.Tensor,
                     args=(
                         input,
@@ -148,13 +152,13 @@ class DecomposeSliceScatter(PassBase):
                         extract_shape(input)[dim],
                         1,
                     ),
+                    origin=node,
                 )
                 slices.append(slice_1)
 
-            concat = graph.call_function(
-                torch.ops.aten.cat.default, args=(slices, dim)
+            concat = create_node(
+                graph, torch.ops.aten.cat.default, args=(slices, dim)
             )
-            # Not set meta for propagating replacing node's meta.
             node.replace_all_uses_with(concat, propagate_meta=True)
 
             modified = True
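DecomposeSliceScatter above rewrites slice_scatter as the untouched front slice, the source tensor, and the untouched back slice joined with cat along the scatter dimension. A small illustrative check of that equivalence (not package code):

import torch

x = torch.randn(4, 10)
src = torch.randn(4, 3)
dim, start = 1, 2
end = start + src.shape[dim]

ref = torch.slice_scatter(x, src, dim=dim, start=start, end=end)
out = torch.cat([x[:, :start], src, x[:, end:]], dim=dim)
assert torch.equal(ref, out)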
tico/passes/fuse_leading_unsqueeze_reshape.py
CHANGED
@@ -20,6 +20,7 @@ from torch.export import ExportedProgram
 from tico.passes import ops
 from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -84,15 +85,19 @@ class FuseLeadingUnsqueezeReshape(PassBase):
         k = len(back_shape) - len(permute_shape)
         with graph.inserting_before(permute):
             new_shape = [1] * k + list(reshape_front_size)
-            r_new = graph.call_function(
+            r_new = create_node(
+                graph,
                 torch.ops.aten.reshape.default,
                 args=(reshape_front_input, new_shape),
+                origin=reshape_back,
             )
             new_p_dims = list(range(k)) + [
                 d + k for d in permute_dims
             ]  # shift by k
-            p_new = graph.call_function(
-                torch.ops.aten.permute.default, args=(r_new, new_p_dims)
+            p_new = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(r_new, new_p_dims),
             )
 
         reshape_back.replace_all_uses_with(p_new, propagate_meta=True)
tico/passes/legalize_predefined_layout_operators.py
CHANGED
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
 from tico.serialize.circle_graph import extract_shape
 from tico.utils import logging
 from tico.utils.errors import NotYetSupportedError
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -124,9 +125,11 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         # TODO Introduce a method that inserts permute op.
         # input permute
         with graph.inserting_after(input):
-            input_permute = graph.call_function(
+            input_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(input, NCHW_to_NHWC),
+                origin=input,
             )
         node.update_arg(node.args.index(input), input_permute)
 
@@ -142,9 +145,11 @@ class LegalizePreDefinedLayoutOperators(PassBase):
             else:
                 assert groups == 1 or groups == input_shape[1]  # Cannot reach here
 
-            weight_permute = graph.call_function(
+            weight_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(weight, perm),
+                origin=weight,
             )
             if args.weight.target in [
                 torch.ops.quantized_decomposed.dequantize_per_channel.default,
@@ -171,18 +176,16 @@ class LegalizePreDefinedLayoutOperators(PassBase):
             else:
                 assert groups == 1 or groups == input_shape[1]  # Cannot reach here
 
-            circle_op = graph.call_function(
-                legalized_op,
-                args=node.args,
-                kwargs=node.kwargs,
+            circle_op = create_node(
+                graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
             )
             # output permute
             NHWC_to_NCHW = [0, 3, 1, 2]
-            conv_out_permute = graph.call_function(
+            conv_out_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(circle_op, NHWC_to_NCHW),
             )
-        # Not set meta for propagating replacing node's meta.
         node.replace_all_uses_with(conv_out_permute, propagate_meta=True)
 
         logger.debug(f"{node.name} is replaced with {circle_op.name}")
@@ -224,25 +227,29 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         with graph.inserting_after(input):
             # input permute
             NCHW_to_NHWC = [0, 2, 3, 1]
-            input_permute = graph.call_function(
+            input_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(input, NCHW_to_NHWC),
+                origin=input,
             )
         node.update_arg(node.args.index(input), input_permute)
         with graph.inserting_before(node):
             # circle instnorm
-            circle_instnorm = graph.call_function(
+            circle_instnorm = create_node(
+                graph,
                 torch.ops.circle_custom.instance_norm,
                 args=node.args,
                 kwargs=node.kwargs,
+                origin=node,
             )
             # output permute
             NHWC_to_NCHW = [0, 3, 1, 2]
-            instnorm_out_permute = graph.call_function(
+            instnorm_out_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(circle_instnorm, NHWC_to_NCHW),
             )
-        # Not set meta for propagating replacing node's meta.
         node.replace_all_uses_with(instnorm_out_permute, propagate_meta=True)
 
         logger.debug(f"{node.name} is replaced with {circle_instnorm.name}")
@@ -275,25 +282,25 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         # TODO Introduce a method that inserts permute op.
         # input permute
        with graph.inserting_after(input_):
-            input_permute = graph.call_function(
+            input_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(input_, NCHW_to_NHWC),
+                origin=input_,
             )
         node.update_arg(node.args.index(input_), input_permute)
         with graph.inserting_before(node):
             legalized_op = torch.ops.circle_custom.maxpool2d
-            circle_maxpool2d = graph.call_function(
-                legalized_op,
-                args=node.args,
-                kwargs=node.kwargs,
+            circle_maxpool2d = create_node(
+                graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
             )
             # output permute
             NHWC_to_NCHW = [0, 3, 1, 2]
-            maxpool_out_permute = graph.call_function(
+            maxpool_out_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(circle_maxpool2d, NHWC_to_NCHW),
             )
-        # Not set meta for propagating replacing get_item's meta.
         get_item, *_ = node.users.keys()
         get_item.replace_all_uses_with(maxpool_out_permute, propagate_meta=True)
 
@@ -327,21 +334,22 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         # TODO Introduce a method that inserts permute op.
         # input permute
         with graph.inserting_after(input_):
-            input_permute = graph.call_function(
+            input_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(input_, NCHW_to_NHWC),
+                origin=input_,
             )
         node.update_arg(node.args.index(input_), input_permute)
         with graph.inserting_before(node):
             legalized_op = torch.ops.circle_custom.avgpool2d
-            circle_avgpool2d = graph.call_function(
-                legalized_op,
-                args=node.args,
-                kwargs=node.kwargs,
+            circle_avgpool2d = create_node(
+                graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
            )
             # output permute
             NHWC_to_NCHW = [0, 3, 1, 2]
-            avgpool_out_permute = graph.call_function(
+            avgpool_out_permute = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(circle_avgpool2d, NHWC_to_NCHW),
             )
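Every legalization branch above uses the same sandwich: a [0, 2, 3, 1] permute feeds the NHWC-only circle op and a [0, 3, 1, 2] permute restores the original layout. The two permutations are exact inverses:

import torch

x = torch.randn(2, 3, 5, 7)                      # NCHW
nhwc = x.permute(0, 2, 3, 1)                     # NCHW -> NHWC
assert torch.equal(nhwc.permute(0, 3, 1, 2), x)  # NHWC -> NCHW round-trips exactly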
tico/passes/lower_pow2_to_mul.py
CHANGED
@@ -20,6 +20,7 @@ import torch
 from torch.export import ExportedProgram
 
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -55,7 +56,8 @@ class LowerPow2ToMul(PassBase):
 
             lhs = rhs = in_
             with graph.inserting_after(node):
-                new_mul = graph.call_function(
+                new_mul = create_node(
+                    graph,
                     torch.ops.aten.mul.Tensor,
                     args=(lhs, rhs),
                     kwargs={},
tico/passes/lower_to_resize_nearest_neighbor.py
CHANGED
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
 from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.errors import NotYetSupportedError
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -108,18 +109,23 @@ class LowerToResizeNearestNeighbor(PassBase):
         assert expected_shape == list(extract_shape(node))
 
         with graph.inserting_before(node):
-            nchw_to_nhwc = graph.call_function(
-                torch.ops.aten.permute.default, args=(input_tensor, [0, 2, 3, 1])
+            nchw_to_nhwc = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(input_tensor, [0, 2, 3, 1]),
+                origin=input_tensor,
             )
-            resize_nearest_neighbor = graph.call_function(
+            resize_nearest_neighbor = create_node(
+                graph,
                 torch.ops.circle_custom.resize_nearest_neighbor,
                 args=(nchw_to_nhwc, [len(expected_H_index), len(expected_W_index)]),
+                origin=node,
             )
-            nhwc_to_nchw = graph.call_function(
+            nhwc_to_nchw = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(resize_nearest_neighbor, [0, 3, 1, 2]),
             )
-        # Not set meta for propagating replacing node's meta.
         node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
 
         return resize_nearest_neighbor
@@ -171,18 +177,23 @@ class LowerToResizeNearestNeighbor(PassBase):
         )
 
         with graph.inserting_before(node):
-            nchw_to_nhwc = graph.call_function(
-                torch.ops.aten.permute.default, args=(input_tensor, [0, 2, 3, 1])
+            nchw_to_nhwc = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(input_tensor, [0, 2, 3, 1]),
+                origin=input_tensor,
             )
-            resize_nearest_neighbor = graph.call_function(
+            resize_nearest_neighbor = create_node(
+                graph,
                 torch.ops.circle_custom.resize_nearest_neighbor,
                 args=(nchw_to_nhwc, [expected_H, expected_W]),
+                origin=node,
             )
-            nhwc_to_nchw = graph.call_function(
+            nhwc_to_nchw = create_node(
+                graph,
                 torch.ops.aten.permute.default,
                 args=(resize_nearest_neighbor, [0, 3, 1, 2]),
             )
-        # Not set meta for propagating replacing node's meta.
         node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
         return resize_nearest_neighbor
 
tico/passes/lower_to_slice.py
CHANGED
@@ -30,7 +30,7 @@ from torch.export import ExportedProgram
 from tico.passes import ops
 from tico.serialize.circle_graph import extract_shape
 from tico.utils import logging
-from tico.utils.graph import is_single_value_tensor
+from tico.utils.graph import create_node, is_single_value_tensor
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_const_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -103,17 +103,22 @@ class LowerSelectCopyToSlice(PassBase):
 
             with graph.inserting_after(node):
                 # slice
-                slice_node = graph.call_function(
-                    torch.ops.aten.slice.Tensor, args=slice_copy_args
+                slice_node = create_node(
+                    graph,
+                    torch.ops.aten.slice.Tensor,
+                    args=slice_copy_args,
+                    origin=node,
                 )
                 node_shape = extract_shape(node)
                 with graph.inserting_after(slice_node):
                     # reshape
                     reshape_args = (slice_node, list(node_shape))
-                    reshape_node = graph.call_function(
-                        torch.ops.aten.reshape.default, args=reshape_args
+                    reshape_node = create_node(
+                        graph,
+                        torch.ops.aten.reshape.default,
+                        args=reshape_args,
                     )
-            node.replace_all_uses_with(reshape_node, propagate_meta=False)
+            node.replace_all_uses_with(reshape_node, propagate_meta=True)
 
             modified = True
             logger.debug(
@@ -196,17 +201,22 @@ class LowerIndexSelectToSlice(PassBase):
 
             with graph.inserting_after(node):
                 # slice
-                slice_node = graph.call_function(
-                    torch.ops.aten.slice.Tensor, args=slice_copy_args
+                slice_node = create_node(
+                    graph,
+                    torch.ops.aten.slice.Tensor,
+                    args=slice_copy_args,
+                    origin=node,
                 )
                 node_shape = extract_shape(node)
                 with graph.inserting_after(slice_node):
                     # reshape
                     reshape_args = (slice_node, list(node_shape))
-                    reshape_node = graph.call_function(
-                        torch.ops.aten.reshape.default, args=reshape_args
+                    reshape_node = create_node(
+                        graph,
+                        torch.ops.aten.reshape.default,
+                        args=reshape_args,
                     )
-            node.replace_all_uses_with(reshape_node, propagate_meta=False)
+            node.replace_all_uses_with(reshape_node, propagate_meta=True)
 
             modified = True
             logger.debug(
tico/passes/remove_redundant_permute.py
CHANGED
@@ -14,7 +14,6 @@
 
 from typing import TYPE_CHECKING
 
-
 if TYPE_CHECKING:
     import torch.fx
 import torch
@@ -23,6 +22,7 @@ from torch.export import ExportedProgram
 from tico.passes import ops
 from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import is_target_node
@@ -106,8 +106,10 @@ class RemoveRedundantPermutePattern1(PassBase):
             else:
                 with graph.inserting_after(permute2):
                     new_args = (permute1_input, fused_dims)
-                    fused_permute = graph.call_function(
-                        torch.ops.aten.permute.default, args=new_args
+                    fused_permute = create_node(
+                        graph,
+                        torch.ops.aten.permute.default,
+                        args=new_args,
                     )
                 permute2.replace_all_uses_with(fused_permute, propagate_meta=True)
                 logger.debug(f"{permute1.name} and {permute2.name} are fused.")
tico/passes/remove_redundant_reshape.py
CHANGED
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
 from tico.passes import ops
 from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
@@ -369,8 +370,10 @@ class RemoveRedundantReshapePattern4(PassBase):
                 assert isinstance(s, int), type(s)
 
             with graph.inserting_before(reshape1):
-                fused_reshape = graph.call_function(
-                    reshape1.target, (reshape1_input, reshape2_size)
+                fused_reshape = create_node(
+                    graph,
+                    reshape1.target,
+                    (reshape1_input, reshape2_size),
                 )
 
             reshape2.replace_all_uses_with(fused_reshape, propagate_meta=True)
tico/passes/remove_redundant_to_copy.py
CHANGED
@@ -70,6 +70,10 @@ class RemoveRedundantToCopy(PassBase):
             if input_dtype != target_dtype:
                 continue
 
+            if hasattr(args, "memory_format") and args.memory_format is not None:
+                if args.memory_format != torch.contiguous_format:
+                    continue
+
             node.replace_all_uses_with(input_, propagate_meta=False)
 
             modified = True
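The new guard in RemoveRedundantToCopy matters because an aten._to_copy whose dtype already matches can still change the tensor's strides when a non-contiguous memory_format is requested, so it is not redundant. Illustrative only:

import torch

x = torch.randn(1, 3, 4, 4)                          # default contiguous (NCHW) strides
y = x.to(memory_format=torch.channels_last)          # same dtype, different layout
assert y.dtype == x.dtype and not y.is_contiguous()  # such a copy must be kept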
tico/passes/restore_linear.py
CHANGED
@@ -16,6 +16,7 @@ import torch
 from torch.export import ExportedProgram
 
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 
@@ -74,11 +75,12 @@ class RestoreLinear(PassBase):
 
             addmm_args = (input, weight, bias)
             with graph.inserting_after(node):
-                linear_node = graph.call_function(
-                    torch.ops.aten.linear.default, args=addmm_args
+                linear_node = create_node(
+                    graph,
+                    torch.ops.aten.linear.default,
+                    args=addmm_args,
                 )
                 node.replace_all_uses_with(linear_node, propagate_meta=True)
-                graph.erase_node(node)
 
         elif node.target == torch.ops.aten.mm.default:
             assert len(node.args) == 2
@@ -97,8 +99,8 @@ class RestoreLinear(PassBase):
 
             mm_args = (input, weight)
             with graph.inserting_after(node):
-                linear_node = graph.call_function(
-                    torch.ops.aten.linear.default, args=mm_args
+                linear_node = create_node(
+                    graph, torch.ops.aten.linear.default, args=mm_args
                 )
                 node.replace_all_uses_with(linear_node, propagate_meta=True)
 