tico 0.1.0.dev250615__py3-none-any.whl → 0.1.0.dev250617__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +6 -2
- tico/passes/cast_aten_where_arg_type.py +4 -1
- tico/passes/cast_mixed_type_args.py +4 -1
- tico/passes/convert_conv1d_to_conv2d.py +12 -4
- tico/passes/convert_layout_op_to_reshape.py +3 -2
- tico/passes/convert_repeat_to_expand_copy.py +5 -2
- tico/passes/convert_to_relu6.py +4 -3
- tico/passes/decompose_addmm.py +11 -7
- tico/passes/decompose_batch_norm.py +7 -11
- tico/passes/decompose_fake_quantize.py +12 -6
- tico/passes/decompose_fake_quantize_tensor_qparams.py +12 -6
- tico/passes/decompose_group_norm.py +50 -21
- tico/passes/decompose_grouped_conv2d.py +15 -7
- tico/passes/decompose_slice_scatter.py +9 -5
- tico/passes/fuse_leading_unsqueeze_reshape.py +8 -3
- tico/passes/legalize_predefined_layout_operators.py +33 -25
- tico/passes/lower_pow2_to_mul.py +3 -1
- tico/passes/lower_to_resize_nearest_neighbor.py +21 -10
- tico/passes/lower_to_slice.py +21 -11
- tico/passes/remove_redundant_permute.py +5 -3
- tico/passes/remove_redundant_reshape.py +5 -2
- tico/passes/remove_redundant_to_copy.py +4 -0
- tico/passes/restore_linear.py +7 -5
- tico/passes/segment_index_select.py +9 -5
- tico/utils/graph.py +48 -2
- {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/METADATA +7 -2
- {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/RECORD +32 -32
- {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
|
|
23
23
|
from tico.serialize.circle_mapping import extract_shape
|
24
24
|
|
25
25
|
from tico.utils import logging
|
26
|
+
from tico.utils.graph import create_node
|
26
27
|
from tico.utils.passes import PassBase, PassResult
|
27
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
29
|
from tico.utils.utils import enforce_type, is_target_node
|
@@ -130,16 +131,19 @@ class DecomposeSliceScatter(PassBase):
|
|
130
131
|
slices = []
|
131
132
|
|
132
133
|
if 0 < start:
|
133
|
-
slice_0 =
|
134
|
+
slice_0 = create_node(
|
135
|
+
graph,
|
134
136
|
torch.ops.aten.slice_copy.Tensor,
|
135
137
|
args=(input, dim, 0, start, 1),
|
138
|
+
origin=node,
|
136
139
|
)
|
137
140
|
slices.append(slice_0)
|
138
141
|
|
139
142
|
slices.append(src)
|
140
143
|
|
141
144
|
if start + end < extract_shape(input)[dim]:
|
142
|
-
slice_1 =
|
145
|
+
slice_1 = create_node(
|
146
|
+
graph,
|
143
147
|
torch.ops.aten.slice_copy.Tensor,
|
144
148
|
args=(
|
145
149
|
input,
|
@@ -148,13 +152,13 @@ class DecomposeSliceScatter(PassBase):
|
|
148
152
|
extract_shape(input)[dim],
|
149
153
|
1,
|
150
154
|
),
|
155
|
+
origin=node,
|
151
156
|
)
|
152
157
|
slices.append(slice_1)
|
153
158
|
|
154
|
-
concat =
|
155
|
-
torch.ops.aten.cat.default, args=(slices, dim)
|
159
|
+
concat = create_node(
|
160
|
+
graph, torch.ops.aten.cat.default, args=(slices, dim)
|
156
161
|
)
|
157
|
-
# Not set meta for propagating replacing node's meta.
|
158
162
|
node.replace_all_uses_with(concat, propagate_meta=True)
|
159
163
|
|
160
164
|
modified = True
|
@@ -20,6 +20,7 @@ from torch.export import ExportedProgram
|
|
20
20
|
from tico.passes import ops
|
21
21
|
from tico.serialize.circle_mapping import extract_shape
|
22
22
|
from tico.utils import logging
|
23
|
+
from tico.utils.graph import create_node
|
23
24
|
from tico.utils.passes import PassBase, PassResult
|
24
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
26
|
from tico.utils.utils import is_target_node
|
@@ -84,15 +85,19 @@ class FuseLeadingUnsqueezeReshape(PassBase):
|
|
84
85
|
k = len(back_shape) - len(permute_shape)
|
85
86
|
with graph.inserting_before(permute):
|
86
87
|
new_shape = [1] * k + list(reshape_front_size)
|
87
|
-
r_new =
|
88
|
+
r_new = create_node(
|
89
|
+
graph,
|
88
90
|
torch.ops.aten.reshape.default,
|
89
91
|
args=(reshape_front_input, new_shape),
|
92
|
+
origin=reshape_back,
|
90
93
|
)
|
91
94
|
new_p_dims = list(range(k)) + [
|
92
95
|
d + k for d in permute_dims
|
93
96
|
] # shift by k
|
94
|
-
p_new =
|
95
|
-
|
97
|
+
p_new = create_node(
|
98
|
+
graph,
|
99
|
+
torch.ops.aten.permute.default,
|
100
|
+
args=(r_new, new_p_dims),
|
96
101
|
)
|
97
102
|
|
98
103
|
reshape_back.replace_all_uses_with(p_new, propagate_meta=True)
|
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
|
|
23
23
|
from tico.serialize.circle_graph import extract_shape
|
24
24
|
from tico.utils import logging
|
25
25
|
from tico.utils.errors import NotYetSupportedError
|
26
|
+
from tico.utils.graph import create_node
|
26
27
|
from tico.utils.passes import PassBase, PassResult
|
27
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
29
|
from tico.utils.utils import is_target_node
|
@@ -124,9 +125,11 @@ class LegalizePreDefinedLayoutOperators(PassBase):
|
|
124
125
|
# TODO Introduce a method that inserts permute op.
|
125
126
|
# input permute
|
126
127
|
with graph.inserting_after(input):
|
127
|
-
input_permute =
|
128
|
+
input_permute = create_node(
|
129
|
+
graph,
|
128
130
|
torch.ops.aten.permute.default,
|
129
131
|
args=(input, NCHW_to_NHWC),
|
132
|
+
origin=input,
|
130
133
|
)
|
131
134
|
node.update_arg(node.args.index(input), input_permute)
|
132
135
|
|
@@ -142,9 +145,11 @@ class LegalizePreDefinedLayoutOperators(PassBase):
|
|
142
145
|
else:
|
143
146
|
assert groups == 1 or groups == input_shape[1] # Cannot reach here
|
144
147
|
|
145
|
-
weight_permute =
|
148
|
+
weight_permute = create_node(
|
149
|
+
graph,
|
146
150
|
torch.ops.aten.permute.default,
|
147
151
|
args=(weight, perm),
|
152
|
+
origin=weight,
|
148
153
|
)
|
149
154
|
if args.weight.target in [
|
150
155
|
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
@@ -171,18 +176,16 @@ class LegalizePreDefinedLayoutOperators(PassBase):
|
|
171
176
|
else:
|
172
177
|
assert groups == 1 or groups == input_shape[1] # Cannot reach here
|
173
178
|
|
174
|
-
circle_op =
|
175
|
-
legalized_op,
|
176
|
-
args=node.args,
|
177
|
-
kwargs=node.kwargs,
|
179
|
+
circle_op = create_node(
|
180
|
+
graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
|
178
181
|
)
|
179
182
|
# output permute
|
180
183
|
NHWC_to_NCHW = [0, 3, 1, 2]
|
181
|
-
conv_out_permute =
|
184
|
+
conv_out_permute = create_node(
|
185
|
+
graph,
|
182
186
|
torch.ops.aten.permute.default,
|
183
187
|
args=(circle_op, NHWC_to_NCHW),
|
184
188
|
)
|
185
|
-
# Not set meta for propagating replacing node's meta.
|
186
189
|
node.replace_all_uses_with(conv_out_permute, propagate_meta=True)
|
187
190
|
|
188
191
|
logger.debug(f"{node.name} is replaced with {circle_op.name}")
|
@@ -224,25 +227,29 @@ class LegalizePreDefinedLayoutOperators(PassBase):
|
|
224
227
|
with graph.inserting_after(input):
|
225
228
|
# input permute
|
226
229
|
NCHW_to_NHWC = [0, 2, 3, 1]
|
227
|
-
input_permute =
|
230
|
+
input_permute = create_node(
|
231
|
+
graph,
|
228
232
|
torch.ops.aten.permute.default,
|
229
233
|
args=(input, NCHW_to_NHWC),
|
234
|
+
origin=input,
|
230
235
|
)
|
231
236
|
node.update_arg(node.args.index(input), input_permute)
|
232
237
|
with graph.inserting_before(node):
|
233
238
|
# circle instnorm
|
234
|
-
circle_instnorm =
|
239
|
+
circle_instnorm = create_node(
|
240
|
+
graph,
|
235
241
|
torch.ops.circle_custom.instance_norm,
|
236
242
|
args=node.args,
|
237
243
|
kwargs=node.kwargs,
|
244
|
+
origin=node,
|
238
245
|
)
|
239
246
|
# output permute
|
240
247
|
NHWC_to_NCHW = [0, 3, 1, 2]
|
241
|
-
instnorm_out_permute =
|
248
|
+
instnorm_out_permute = create_node(
|
249
|
+
graph,
|
242
250
|
torch.ops.aten.permute.default,
|
243
251
|
args=(circle_instnorm, NHWC_to_NCHW),
|
244
252
|
)
|
245
|
-
# Not set meta for propagating replacing node's meta.
|
246
253
|
node.replace_all_uses_with(instnorm_out_permute, propagate_meta=True)
|
247
254
|
|
248
255
|
logger.debug(f"{node.name} is replaced with {circle_instnorm.name}")
|
@@ -275,25 +282,25 @@ class LegalizePreDefinedLayoutOperators(PassBase):
|
|
275
282
|
# TODO Introduce a method that inserts permute op.
|
276
283
|
# input permute
|
277
284
|
with graph.inserting_after(input_):
|
278
|
-
input_permute =
|
285
|
+
input_permute = create_node(
|
286
|
+
graph,
|
279
287
|
torch.ops.aten.permute.default,
|
280
288
|
args=(input_, NCHW_to_NHWC),
|
289
|
+
origin=input_,
|
281
290
|
)
|
282
291
|
node.update_arg(node.args.index(input_), input_permute)
|
283
292
|
with graph.inserting_before(node):
|
284
293
|
legalized_op = torch.ops.circle_custom.maxpool2d
|
285
|
-
circle_maxpool2d =
|
286
|
-
legalized_op,
|
287
|
-
args=node.args,
|
288
|
-
kwargs=node.kwargs,
|
294
|
+
circle_maxpool2d = create_node(
|
295
|
+
graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
|
289
296
|
)
|
290
297
|
# output permute
|
291
298
|
NHWC_to_NCHW = [0, 3, 1, 2]
|
292
|
-
maxpool_out_permute =
|
299
|
+
maxpool_out_permute = create_node(
|
300
|
+
graph,
|
293
301
|
torch.ops.aten.permute.default,
|
294
302
|
args=(circle_maxpool2d, NHWC_to_NCHW),
|
295
303
|
)
|
296
|
-
# Not set meta for propagating replacing get_item's meta.
|
297
304
|
get_item, *_ = node.users.keys()
|
298
305
|
get_item.replace_all_uses_with(maxpool_out_permute, propagate_meta=True)
|
299
306
|
|
@@ -327,21 +334,22 @@ class LegalizePreDefinedLayoutOperators(PassBase):
|
|
327
334
|
# TODO Introduce a method that inserts permute op.
|
328
335
|
# input permute
|
329
336
|
with graph.inserting_after(input_):
|
330
|
-
input_permute =
|
337
|
+
input_permute = create_node(
|
338
|
+
graph,
|
331
339
|
torch.ops.aten.permute.default,
|
332
340
|
args=(input_, NCHW_to_NHWC),
|
341
|
+
origin=input_,
|
333
342
|
)
|
334
343
|
node.update_arg(node.args.index(input_), input_permute)
|
335
344
|
with graph.inserting_before(node):
|
336
345
|
legalized_op = torch.ops.circle_custom.avgpool2d
|
337
|
-
circle_avgpool2d =
|
338
|
-
legalized_op,
|
339
|
-
args=node.args,
|
340
|
-
kwargs=node.kwargs,
|
346
|
+
circle_avgpool2d = create_node(
|
347
|
+
graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
|
341
348
|
)
|
342
349
|
# output permute
|
343
350
|
NHWC_to_NCHW = [0, 3, 1, 2]
|
344
|
-
avgpool_out_permute =
|
351
|
+
avgpool_out_permute = create_node(
|
352
|
+
graph,
|
345
353
|
torch.ops.aten.permute.default,
|
346
354
|
args=(circle_avgpool2d, NHWC_to_NCHW),
|
347
355
|
)
|
tico/passes/lower_pow2_to_mul.py
CHANGED
@@ -20,6 +20,7 @@ import torch
|
|
20
20
|
from torch.export import ExportedProgram
|
21
21
|
|
22
22
|
from tico.utils import logging
|
23
|
+
from tico.utils.graph import create_node
|
23
24
|
from tico.utils.passes import PassBase, PassResult
|
24
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
26
|
from tico.utils.utils import is_target_node
|
@@ -55,7 +56,8 @@ class LowerPow2ToMul(PassBase):
|
|
55
56
|
|
56
57
|
lhs = rhs = in_
|
57
58
|
with graph.inserting_after(node):
|
58
|
-
new_mul =
|
59
|
+
new_mul = create_node(
|
60
|
+
graph,
|
59
61
|
torch.ops.aten.mul.Tensor,
|
60
62
|
args=(lhs, rhs),
|
61
63
|
kwargs={},
|
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.serialize.circle_mapping import extract_shape
|
23
23
|
from tico.utils import logging
|
24
24
|
from tico.utils.errors import NotYetSupportedError
|
25
|
+
from tico.utils.graph import create_node
|
25
26
|
from tico.utils.passes import PassBase, PassResult
|
26
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
28
|
from tico.utils.utils import is_target_node
|
@@ -108,18 +109,23 @@ class LowerToResizeNearestNeighbor(PassBase):
|
|
108
109
|
assert expected_shape == list(extract_shape(node))
|
109
110
|
|
110
111
|
with graph.inserting_before(node):
|
111
|
-
nchw_to_nhwc =
|
112
|
-
|
112
|
+
nchw_to_nhwc = create_node(
|
113
|
+
graph,
|
114
|
+
torch.ops.aten.permute.default,
|
115
|
+
args=(input_tensor, [0, 2, 3, 1]),
|
116
|
+
origin=input_tensor,
|
113
117
|
)
|
114
|
-
resize_nearest_neighbor =
|
118
|
+
resize_nearest_neighbor = create_node(
|
119
|
+
graph,
|
115
120
|
torch.ops.circle_custom.resize_nearest_neighbor,
|
116
121
|
args=(nchw_to_nhwc, [len(expected_H_index), len(expected_W_index)]),
|
122
|
+
origin=node,
|
117
123
|
)
|
118
|
-
nhwc_to_nchw =
|
124
|
+
nhwc_to_nchw = create_node(
|
125
|
+
graph,
|
119
126
|
torch.ops.aten.permute.default,
|
120
127
|
args=(resize_nearest_neighbor, [0, 3, 1, 2]),
|
121
128
|
)
|
122
|
-
# Not set meta for propagating replacing node's meta.
|
123
129
|
node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
|
124
130
|
|
125
131
|
return resize_nearest_neighbor
|
@@ -171,18 +177,23 @@ class LowerToResizeNearestNeighbor(PassBase):
|
|
171
177
|
)
|
172
178
|
|
173
179
|
with graph.inserting_before(node):
|
174
|
-
nchw_to_nhwc =
|
175
|
-
|
180
|
+
nchw_to_nhwc = create_node(
|
181
|
+
graph,
|
182
|
+
torch.ops.aten.permute.default,
|
183
|
+
args=(input_tensor, [0, 2, 3, 1]),
|
184
|
+
origin=input_tensor,
|
176
185
|
)
|
177
|
-
resize_nearest_neighbor =
|
186
|
+
resize_nearest_neighbor = create_node(
|
187
|
+
graph,
|
178
188
|
torch.ops.circle_custom.resize_nearest_neighbor,
|
179
189
|
args=(nchw_to_nhwc, [expected_H, expected_W]),
|
190
|
+
origin=node,
|
180
191
|
)
|
181
|
-
nhwc_to_nchw =
|
192
|
+
nhwc_to_nchw = create_node(
|
193
|
+
graph,
|
182
194
|
torch.ops.aten.permute.default,
|
183
195
|
args=(resize_nearest_neighbor, [0, 3, 1, 2]),
|
184
196
|
)
|
185
|
-
# Not set meta for propagating replacing node's meta.
|
186
197
|
node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
|
187
198
|
return resize_nearest_neighbor
|
188
199
|
|
tico/passes/lower_to_slice.py
CHANGED
@@ -30,7 +30,7 @@ from torch.export import ExportedProgram
|
|
30
30
|
from tico.passes import ops
|
31
31
|
from tico.serialize.circle_graph import extract_shape
|
32
32
|
from tico.utils import logging
|
33
|
-
from tico.utils.graph import is_single_value_tensor
|
33
|
+
from tico.utils.graph import create_node, is_single_value_tensor
|
34
34
|
from tico.utils.passes import PassBase, PassResult
|
35
35
|
from tico.utils.trace_decorators import trace_const_diff_on_pass
|
36
36
|
from tico.utils.utils import is_target_node
|
@@ -103,17 +103,22 @@ class LowerSelectCopyToSlice(PassBase):
|
|
103
103
|
|
104
104
|
with graph.inserting_after(node):
|
105
105
|
# slice
|
106
|
-
slice_node =
|
107
|
-
|
106
|
+
slice_node = create_node(
|
107
|
+
graph,
|
108
|
+
torch.ops.aten.slice.Tensor,
|
109
|
+
args=slice_copy_args,
|
110
|
+
origin=node,
|
108
111
|
)
|
109
112
|
node_shape = extract_shape(node)
|
110
113
|
with graph.inserting_after(slice_node):
|
111
114
|
# reshape
|
112
115
|
reshape_args = (slice_node, list(node_shape))
|
113
|
-
reshape_node =
|
114
|
-
|
116
|
+
reshape_node = create_node(
|
117
|
+
graph,
|
118
|
+
torch.ops.aten.reshape.default,
|
119
|
+
args=reshape_args,
|
115
120
|
)
|
116
|
-
node.replace_all_uses_with(reshape_node, propagate_meta=
|
121
|
+
node.replace_all_uses_with(reshape_node, propagate_meta=True)
|
117
122
|
|
118
123
|
modified = True
|
119
124
|
logger.debug(
|
@@ -196,17 +201,22 @@ class LowerIndexSelectToSlice(PassBase):
|
|
196
201
|
|
197
202
|
with graph.inserting_after(node):
|
198
203
|
# slice
|
199
|
-
slice_node =
|
200
|
-
|
204
|
+
slice_node = create_node(
|
205
|
+
graph,
|
206
|
+
torch.ops.aten.slice.Tensor,
|
207
|
+
args=slice_copy_args,
|
208
|
+
origin=node,
|
201
209
|
)
|
202
210
|
node_shape = extract_shape(node)
|
203
211
|
with graph.inserting_after(slice_node):
|
204
212
|
# reshape
|
205
213
|
reshape_args = (slice_node, list(node_shape))
|
206
|
-
reshape_node =
|
207
|
-
|
214
|
+
reshape_node = create_node(
|
215
|
+
graph,
|
216
|
+
torch.ops.aten.reshape.default,
|
217
|
+
args=reshape_args,
|
208
218
|
)
|
209
|
-
node.replace_all_uses_with(reshape_node, propagate_meta=
|
219
|
+
node.replace_all_uses_with(reshape_node, propagate_meta=True)
|
210
220
|
|
211
221
|
modified = True
|
212
222
|
logger.debug(
|
@@ -14,7 +14,6 @@
|
|
14
14
|
|
15
15
|
from typing import TYPE_CHECKING
|
16
16
|
|
17
|
-
|
18
17
|
if TYPE_CHECKING:
|
19
18
|
import torch.fx
|
20
19
|
import torch
|
@@ -23,6 +22,7 @@ from torch.export import ExportedProgram
|
|
23
22
|
from tico.passes import ops
|
24
23
|
from tico.serialize.circle_mapping import extract_shape
|
25
24
|
from tico.utils import logging
|
25
|
+
from tico.utils.graph import create_node
|
26
26
|
from tico.utils.passes import PassBase, PassResult
|
27
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
28
|
from tico.utils.utils import is_target_node
|
@@ -106,8 +106,10 @@ class RemoveRedundantPermutePattern1(PassBase):
|
|
106
106
|
else:
|
107
107
|
with graph.inserting_after(permute2):
|
108
108
|
new_args = (permute1_input, fused_dims)
|
109
|
-
fused_permute =
|
110
|
-
|
109
|
+
fused_permute = create_node(
|
110
|
+
graph,
|
111
|
+
torch.ops.aten.permute.default,
|
112
|
+
args=new_args,
|
111
113
|
)
|
112
114
|
permute2.replace_all_uses_with(fused_permute, propagate_meta=True)
|
113
115
|
logger.debug(f"{permute1.name} and {permute2.name} are fused.")
|
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.passes import ops
|
23
23
|
from tico.serialize.circle_mapping import extract_shape
|
24
24
|
from tico.utils import logging
|
25
|
+
from tico.utils.graph import create_node
|
25
26
|
from tico.utils.passes import PassBase, PassResult
|
26
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
28
|
from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
|
@@ -369,8 +370,10 @@ class RemoveRedundantReshapePattern4(PassBase):
|
|
369
370
|
assert isinstance(s, int), type(s)
|
370
371
|
|
371
372
|
with graph.inserting_before(reshape1):
|
372
|
-
fused_reshape =
|
373
|
-
|
373
|
+
fused_reshape = create_node(
|
374
|
+
graph,
|
375
|
+
reshape1.target,
|
376
|
+
(reshape1_input, reshape2_size),
|
374
377
|
)
|
375
378
|
|
376
379
|
reshape2.replace_all_uses_with(fused_reshape, propagate_meta=True)
|
@@ -70,6 +70,10 @@ class RemoveRedundantToCopy(PassBase):
|
|
70
70
|
if input_dtype != target_dtype:
|
71
71
|
continue
|
72
72
|
|
73
|
+
if hasattr(args, "memory_format") and args.memory_format is not None:
|
74
|
+
if args.memory_format != torch.contiguous_format:
|
75
|
+
continue
|
76
|
+
|
73
77
|
node.replace_all_uses_with(input_, propagate_meta=False)
|
74
78
|
|
75
79
|
modified = True
|
tico/passes/restore_linear.py
CHANGED
@@ -16,6 +16,7 @@ import torch
|
|
16
16
|
from torch.export import ExportedProgram
|
17
17
|
|
18
18
|
from tico.utils import logging
|
19
|
+
from tico.utils.graph import create_node
|
19
20
|
from tico.utils.passes import PassBase, PassResult
|
20
21
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
21
22
|
|
@@ -74,11 +75,12 @@ class RestoreLinear(PassBase):
|
|
74
75
|
|
75
76
|
addmm_args = (input, weight, bias)
|
76
77
|
with graph.inserting_after(node):
|
77
|
-
linear_node =
|
78
|
-
|
78
|
+
linear_node = create_node(
|
79
|
+
graph,
|
80
|
+
torch.ops.aten.linear.default,
|
81
|
+
args=addmm_args,
|
79
82
|
)
|
80
83
|
node.replace_all_uses_with(linear_node, propagate_meta=True)
|
81
|
-
graph.erase_node(node)
|
82
84
|
|
83
85
|
elif node.target == torch.ops.aten.mm.default:
|
84
86
|
assert len(node.args) == 2
|
@@ -97,8 +99,8 @@ class RestoreLinear(PassBase):
|
|
97
99
|
|
98
100
|
mm_args = (input, weight)
|
99
101
|
with graph.inserting_after(node):
|
100
|
-
linear_node =
|
101
|
-
torch.ops.aten.linear.default, args=mm_args
|
102
|
+
linear_node = create_node(
|
103
|
+
graph, torch.ops.aten.linear.default, args=mm_args
|
102
104
|
)
|
103
105
|
node.replace_all_uses_with(linear_node, propagate_meta=True)
|
104
106
|
|
@@ -31,7 +31,7 @@ from torch.export import ExportedProgram
|
|
31
31
|
from tico.passes import ops
|
32
32
|
from tico.serialize.circle_graph import extract_shape
|
33
33
|
from tico.utils import logging
|
34
|
-
from tico.utils.graph import add_placeholder, is_single_value_tensor
|
34
|
+
from tico.utils.graph import add_placeholder, create_node, is_single_value_tensor
|
35
35
|
from tico.utils.passes import PassBase, PassResult
|
36
36
|
from tico.utils.trace_decorators import trace_const_diff_on_pass
|
37
37
|
from tico.utils.utils import is_target_node
|
@@ -116,18 +116,22 @@ class SegmentIndexSelectConst(PassBase):
|
|
116
116
|
exported_program, torch.tensor([i]), prefix="segm_index"
|
117
117
|
)
|
118
118
|
with graph.inserting_before(node):
|
119
|
-
index_select_node =
|
119
|
+
index_select_node = create_node(
|
120
|
+
graph,
|
120
121
|
torch.ops.aten.index_select.default,
|
121
122
|
args=(input, dim, index_node),
|
123
|
+
origin=node,
|
122
124
|
)
|
123
125
|
index_select_node_list.append(index_select_node)
|
124
126
|
|
125
127
|
with graph.inserting_before(node):
|
126
|
-
concat_node =
|
127
|
-
|
128
|
+
concat_node = create_node(
|
129
|
+
graph,
|
130
|
+
torch.ops.aten.cat.default,
|
131
|
+
args=(index_select_node_list, dim),
|
128
132
|
)
|
129
133
|
|
130
|
-
node.replace_all_uses_with(concat_node, propagate_meta=
|
134
|
+
node.replace_all_uses_with(concat_node, propagate_meta=True)
|
131
135
|
|
132
136
|
modified = True
|
133
137
|
logger.debug(
|
tico/utils/graph.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16
16
|
# See the License for the specific language governing permissions and
|
17
17
|
# limitations under the License.
|
18
18
|
|
19
|
-
from typing import Optional, TYPE_CHECKING
|
19
|
+
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
|
20
20
|
|
21
21
|
if TYPE_CHECKING:
|
22
22
|
import torch.fx
|
@@ -24,7 +24,7 @@ import torch
|
|
24
24
|
from torch.export import ExportedProgram
|
25
25
|
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
|
26
26
|
|
27
|
-
from tico.utils.utils import get_fake_mode
|
27
|
+
from tico.utils.utils import get_fake_mode, set_new_meta_val
|
28
28
|
|
29
29
|
|
30
30
|
def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
|
@@ -234,3 +234,49 @@ def get_module_name_chain(node: Optional[torch.fx.Node]) -> str:
|
|
234
234
|
return next(reversed(stack.values()))[1]
|
235
235
|
else:
|
236
236
|
return "unknown"
|
237
|
+
|
238
|
+
|
239
|
+
def create_node(
|
240
|
+
graph: torch.fx.Graph,
|
241
|
+
target: torch._ops.OpOverload,
|
242
|
+
args: Optional[Tuple[Any, ...]] = None,
|
243
|
+
kwargs: Optional[Dict[str, Any]] = None,
|
244
|
+
*,
|
245
|
+
origin: Optional[torch.fx.Node] = None,
|
246
|
+
) -> torch.fx.Node:
|
247
|
+
"""
|
248
|
+
Insert a new node into graph and propagate metadata from *origin*.
|
249
|
+
|
250
|
+
Parameters
|
251
|
+
----------
|
252
|
+
graph : torch.fx.Graph
|
253
|
+
The graph that will own the newly-created node.
|
254
|
+
|
255
|
+
target : torch._ops.OpOverload
|
256
|
+
The op to call (e.g. `torch.add` or "call_function" target).
|
257
|
+
|
258
|
+
args : Tuple[Any, ...], optional
|
259
|
+
Positional arguments for the new node.
|
260
|
+
|
261
|
+
kwargs : Dict[str, Any], optional
|
262
|
+
Keyword arguments for the new node.
|
263
|
+
|
264
|
+
origin : torch.fx.Node, optional
|
265
|
+
If given, every key in `origin.meta` **except** "val" is copied
|
266
|
+
onto the new node. "val" is recomputed from *args* /*kwargs* using
|
267
|
+
the internal meta-inference helper.
|
268
|
+
|
269
|
+
Returns
|
270
|
+
-------
|
271
|
+
torch.fx.Node
|
272
|
+
The freshly inserted node with fully-populated `.meta`.
|
273
|
+
"""
|
274
|
+
new_node = graph.call_function(target, args=args, kwargs=kwargs)
|
275
|
+
if origin:
|
276
|
+
assert isinstance(origin, torch.fx.Node), type(origin)
|
277
|
+
# Propagate "nn_module_stack" to retain the originating module context
|
278
|
+
# for meaningful node names.
|
279
|
+
if "nn_module_stack" in origin.meta:
|
280
|
+
new_node.meta["nn_module_stack"] = origin.meta["nn_module_stack"]
|
281
|
+
|
282
|
+
return new_node
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: tico
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.dev250617
|
4
4
|
Summary: Convert exported Torch module to circle
|
5
5
|
Home-page: UNKNOWN
|
6
6
|
License: UNKNOWN
|
@@ -105,10 +105,15 @@ You can convert a torch module to a circle model with these steps.
|
|
105
105
|
torch_module = AddModule()
|
106
106
|
example_inputs = (torch.ones(4), torch.ones(4))
|
107
107
|
|
108
|
-
circle_model = tico.convert(torch_module, example_inputs)
|
108
|
+
circle_model = tico.convert(torch_module.eval(), example_inputs)
|
109
109
|
circle_model.save('add.circle')
|
110
110
|
```
|
111
111
|
|
112
|
+
**NOTE**
|
113
|
+
Please make sure to call `eval()` on the PyTorch module before passing it to our API.
|
114
|
+
This ensures the model runs in inference mode, disabling layers like dropout and
|
115
|
+
batch normalization updates.
|
116
|
+
|
112
117
|
**Compile with configuration**
|
113
118
|
|
114
119
|
```python
|