tico 0.1.0.dev250615__py3-none-any.whl → 0.1.0.dev250617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. tico/__init__.py +1 -1
  2. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +6 -2
  3. tico/passes/cast_aten_where_arg_type.py +4 -1
  4. tico/passes/cast_mixed_type_args.py +4 -1
  5. tico/passes/convert_conv1d_to_conv2d.py +12 -4
  6. tico/passes/convert_layout_op_to_reshape.py +3 -2
  7. tico/passes/convert_repeat_to_expand_copy.py +5 -2
  8. tico/passes/convert_to_relu6.py +4 -3
  9. tico/passes/decompose_addmm.py +11 -7
  10. tico/passes/decompose_batch_norm.py +7 -11
  11. tico/passes/decompose_fake_quantize.py +12 -6
  12. tico/passes/decompose_fake_quantize_tensor_qparams.py +12 -6
  13. tico/passes/decompose_group_norm.py +50 -21
  14. tico/passes/decompose_grouped_conv2d.py +15 -7
  15. tico/passes/decompose_slice_scatter.py +9 -5
  16. tico/passes/fuse_leading_unsqueeze_reshape.py +8 -3
  17. tico/passes/legalize_predefined_layout_operators.py +33 -25
  18. tico/passes/lower_pow2_to_mul.py +3 -1
  19. tico/passes/lower_to_resize_nearest_neighbor.py +21 -10
  20. tico/passes/lower_to_slice.py +21 -11
  21. tico/passes/remove_redundant_permute.py +5 -3
  22. tico/passes/remove_redundant_reshape.py +5 -2
  23. tico/passes/remove_redundant_to_copy.py +4 -0
  24. tico/passes/restore_linear.py +7 -5
  25. tico/passes/segment_index_select.py +9 -5
  26. tico/utils/graph.py +48 -2
  27. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/METADATA +7 -2
  28. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/RECORD +32 -32
  29. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/LICENSE +0 -0
  30. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/WHEEL +0 -0
  31. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/entry_points.txt +0 -0
  32. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
23
23
  from tico.serialize.circle_mapping import extract_shape
24
24
 
25
25
  from tico.utils import logging
26
+ from tico.utils.graph import create_node
26
27
  from tico.utils.passes import PassBase, PassResult
27
28
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
29
  from tico.utils.utils import enforce_type, is_target_node
@@ -130,16 +131,19 @@ class DecomposeSliceScatter(PassBase):
130
131
  slices = []
131
132
 
132
133
  if 0 < start:
133
- slice_0 = graph.call_function(
134
+ slice_0 = create_node(
135
+ graph,
134
136
  torch.ops.aten.slice_copy.Tensor,
135
137
  args=(input, dim, 0, start, 1),
138
+ origin=node,
136
139
  )
137
140
  slices.append(slice_0)
138
141
 
139
142
  slices.append(src)
140
143
 
141
144
  if start + end < extract_shape(input)[dim]:
142
- slice_1 = graph.call_function(
145
+ slice_1 = create_node(
146
+ graph,
143
147
  torch.ops.aten.slice_copy.Tensor,
144
148
  args=(
145
149
  input,
@@ -148,13 +152,13 @@ class DecomposeSliceScatter(PassBase):
148
152
  extract_shape(input)[dim],
149
153
  1,
150
154
  ),
155
+ origin=node,
151
156
  )
152
157
  slices.append(slice_1)
153
158
 
154
- concat = graph.call_function(
155
- torch.ops.aten.cat.default, args=(slices, dim)
159
+ concat = create_node(
160
+ graph, torch.ops.aten.cat.default, args=(slices, dim)
156
161
  )
157
- # Not set meta for propagating replacing node's meta.
158
162
  node.replace_all_uses_with(concat, propagate_meta=True)
159
163
 
160
164
  modified = True
@@ -20,6 +20,7 @@ from torch.export import ExportedProgram
20
20
  from tico.passes import ops
21
21
  from tico.serialize.circle_mapping import extract_shape
22
22
  from tico.utils import logging
23
+ from tico.utils.graph import create_node
23
24
  from tico.utils.passes import PassBase, PassResult
24
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
26
  from tico.utils.utils import is_target_node
@@ -84,15 +85,19 @@ class FuseLeadingUnsqueezeReshape(PassBase):
84
85
  k = len(back_shape) - len(permute_shape)
85
86
  with graph.inserting_before(permute):
86
87
  new_shape = [1] * k + list(reshape_front_size)
87
- r_new = graph.call_function(
88
+ r_new = create_node(
89
+ graph,
88
90
  torch.ops.aten.reshape.default,
89
91
  args=(reshape_front_input, new_shape),
92
+ origin=reshape_back,
90
93
  )
91
94
  new_p_dims = list(range(k)) + [
92
95
  d + k for d in permute_dims
93
96
  ] # shift by k
94
- p_new = graph.call_function(
95
- torch.ops.aten.permute.default, args=(r_new, new_p_dims)
97
+ p_new = create_node(
98
+ graph,
99
+ torch.ops.aten.permute.default,
100
+ args=(r_new, new_p_dims),
96
101
  )
97
102
 
98
103
  reshape_back.replace_all_uses_with(p_new, propagate_meta=True)
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
23
23
  from tico.serialize.circle_graph import extract_shape
24
24
  from tico.utils import logging
25
25
  from tico.utils.errors import NotYetSupportedError
26
+ from tico.utils.graph import create_node
26
27
  from tico.utils.passes import PassBase, PassResult
27
28
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
29
  from tico.utils.utils import is_target_node
@@ -124,9 +125,11 @@ class LegalizePreDefinedLayoutOperators(PassBase):
124
125
  # TODO Introduce a method that inserts permute op.
125
126
  # input permute
126
127
  with graph.inserting_after(input):
127
- input_permute = graph_module.graph.call_function(
128
+ input_permute = create_node(
129
+ graph,
128
130
  torch.ops.aten.permute.default,
129
131
  args=(input, NCHW_to_NHWC),
132
+ origin=input,
130
133
  )
131
134
  node.update_arg(node.args.index(input), input_permute)
132
135
 
@@ -142,9 +145,11 @@ class LegalizePreDefinedLayoutOperators(PassBase):
142
145
  else:
143
146
  assert groups == 1 or groups == input_shape[1] # Cannot reach here
144
147
 
145
- weight_permute = graph_module.graph.call_function(
148
+ weight_permute = create_node(
149
+ graph,
146
150
  torch.ops.aten.permute.default,
147
151
  args=(weight, perm),
152
+ origin=weight,
148
153
  )
149
154
  if args.weight.target in [
150
155
  torch.ops.quantized_decomposed.dequantize_per_channel.default,
@@ -171,18 +176,16 @@ class LegalizePreDefinedLayoutOperators(PassBase):
171
176
  else:
172
177
  assert groups == 1 or groups == input_shape[1] # Cannot reach here
173
178
 
174
- circle_op = graph_module.graph.call_function(
175
- legalized_op,
176
- args=node.args,
177
- kwargs=node.kwargs,
179
+ circle_op = create_node(
180
+ graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
178
181
  )
179
182
  # output permute
180
183
  NHWC_to_NCHW = [0, 3, 1, 2]
181
- conv_out_permute = graph_module.graph.call_function(
184
+ conv_out_permute = create_node(
185
+ graph,
182
186
  torch.ops.aten.permute.default,
183
187
  args=(circle_op, NHWC_to_NCHW),
184
188
  )
185
- # Not set meta for propagating replacing node's meta.
186
189
  node.replace_all_uses_with(conv_out_permute, propagate_meta=True)
187
190
 
188
191
  logger.debug(f"{node.name} is replaced with {circle_op.name}")
@@ -224,25 +227,29 @@ class LegalizePreDefinedLayoutOperators(PassBase):
224
227
  with graph.inserting_after(input):
225
228
  # input permute
226
229
  NCHW_to_NHWC = [0, 2, 3, 1]
227
- input_permute = graph_module.graph.call_function(
230
+ input_permute = create_node(
231
+ graph,
228
232
  torch.ops.aten.permute.default,
229
233
  args=(input, NCHW_to_NHWC),
234
+ origin=input,
230
235
  )
231
236
  node.update_arg(node.args.index(input), input_permute)
232
237
  with graph.inserting_before(node):
233
238
  # circle instnorm
234
- circle_instnorm = graph_module.graph.call_function(
239
+ circle_instnorm = create_node(
240
+ graph,
235
241
  torch.ops.circle_custom.instance_norm,
236
242
  args=node.args,
237
243
  kwargs=node.kwargs,
244
+ origin=node,
238
245
  )
239
246
  # output permute
240
247
  NHWC_to_NCHW = [0, 3, 1, 2]
241
- instnorm_out_permute = graph_module.graph.call_function(
248
+ instnorm_out_permute = create_node(
249
+ graph,
242
250
  torch.ops.aten.permute.default,
243
251
  args=(circle_instnorm, NHWC_to_NCHW),
244
252
  )
245
- # Not set meta for propagating replacing node's meta.
246
253
  node.replace_all_uses_with(instnorm_out_permute, propagate_meta=True)
247
254
 
248
255
  logger.debug(f"{node.name} is replaced with {circle_instnorm.name}")
@@ -275,25 +282,25 @@ class LegalizePreDefinedLayoutOperators(PassBase):
275
282
  # TODO Introduce a method that inserts permute op.
276
283
  # input permute
277
284
  with graph.inserting_after(input_):
278
- input_permute = graph_module.graph.call_function(
285
+ input_permute = create_node(
286
+ graph,
279
287
  torch.ops.aten.permute.default,
280
288
  args=(input_, NCHW_to_NHWC),
289
+ origin=input_,
281
290
  )
282
291
  node.update_arg(node.args.index(input_), input_permute)
283
292
  with graph.inserting_before(node):
284
293
  legalized_op = torch.ops.circle_custom.maxpool2d
285
- circle_maxpool2d = graph_module.graph.call_function(
286
- legalized_op,
287
- args=node.args,
288
- kwargs=node.kwargs,
294
+ circle_maxpool2d = create_node(
295
+ graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
289
296
  )
290
297
  # output permute
291
298
  NHWC_to_NCHW = [0, 3, 1, 2]
292
- maxpool_out_permute = graph_module.graph.call_function(
299
+ maxpool_out_permute = create_node(
300
+ graph,
293
301
  torch.ops.aten.permute.default,
294
302
  args=(circle_maxpool2d, NHWC_to_NCHW),
295
303
  )
296
- # Not set meta for propagating replacing get_item's meta.
297
304
  get_item, *_ = node.users.keys()
298
305
  get_item.replace_all_uses_with(maxpool_out_permute, propagate_meta=True)
299
306
 
@@ -327,21 +334,22 @@ class LegalizePreDefinedLayoutOperators(PassBase):
327
334
  # TODO Introduce a method that inserts permute op.
328
335
  # input permute
329
336
  with graph.inserting_after(input_):
330
- input_permute = graph_module.graph.call_function(
337
+ input_permute = create_node(
338
+ graph,
331
339
  torch.ops.aten.permute.default,
332
340
  args=(input_, NCHW_to_NHWC),
341
+ origin=input_,
333
342
  )
334
343
  node.update_arg(node.args.index(input_), input_permute)
335
344
  with graph.inserting_before(node):
336
345
  legalized_op = torch.ops.circle_custom.avgpool2d
337
- circle_avgpool2d = graph_module.graph.call_function(
338
- legalized_op,
339
- args=node.args,
340
- kwargs=node.kwargs,
346
+ circle_avgpool2d = create_node(
347
+ graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
341
348
  )
342
349
  # output permute
343
350
  NHWC_to_NCHW = [0, 3, 1, 2]
344
- avgpool_out_permute = graph_module.graph.call_function(
351
+ avgpool_out_permute = create_node(
352
+ graph,
345
353
  torch.ops.aten.permute.default,
346
354
  args=(circle_avgpool2d, NHWC_to_NCHW),
347
355
  )
@@ -20,6 +20,7 @@ import torch
20
20
  from torch.export import ExportedProgram
21
21
 
22
22
  from tico.utils import logging
23
+ from tico.utils.graph import create_node
23
24
  from tico.utils.passes import PassBase, PassResult
24
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
26
  from tico.utils.utils import is_target_node
@@ -55,7 +56,8 @@ class LowerPow2ToMul(PassBase):
55
56
 
56
57
  lhs = rhs = in_
57
58
  with graph.inserting_after(node):
58
- new_mul = graph.call_function(
59
+ new_mul = create_node(
60
+ graph,
59
61
  torch.ops.aten.mul.Tensor,
60
62
  args=(lhs, rhs),
61
63
  kwargs={},
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
22
22
  from tico.serialize.circle_mapping import extract_shape
23
23
  from tico.utils import logging
24
24
  from tico.utils.errors import NotYetSupportedError
25
+ from tico.utils.graph import create_node
25
26
  from tico.utils.passes import PassBase, PassResult
26
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
28
  from tico.utils.utils import is_target_node
@@ -108,18 +109,23 @@ class LowerToResizeNearestNeighbor(PassBase):
108
109
  assert expected_shape == list(extract_shape(node))
109
110
 
110
111
  with graph.inserting_before(node):
111
- nchw_to_nhwc = graph.call_function(
112
- torch.ops.aten.permute.default, args=(input_tensor, [0, 2, 3, 1])
112
+ nchw_to_nhwc = create_node(
113
+ graph,
114
+ torch.ops.aten.permute.default,
115
+ args=(input_tensor, [0, 2, 3, 1]),
116
+ origin=input_tensor,
113
117
  )
114
- resize_nearest_neighbor = graph.call_function(
118
+ resize_nearest_neighbor = create_node(
119
+ graph,
115
120
  torch.ops.circle_custom.resize_nearest_neighbor,
116
121
  args=(nchw_to_nhwc, [len(expected_H_index), len(expected_W_index)]),
122
+ origin=node,
117
123
  )
118
- nhwc_to_nchw = graph.call_function(
124
+ nhwc_to_nchw = create_node(
125
+ graph,
119
126
  torch.ops.aten.permute.default,
120
127
  args=(resize_nearest_neighbor, [0, 3, 1, 2]),
121
128
  )
122
- # Not set meta for propagating replacing node's meta.
123
129
  node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
124
130
 
125
131
  return resize_nearest_neighbor
@@ -171,18 +177,23 @@ class LowerToResizeNearestNeighbor(PassBase):
171
177
  )
172
178
 
173
179
  with graph.inserting_before(node):
174
- nchw_to_nhwc = graph.call_function(
175
- torch.ops.aten.permute.default, args=(input_tensor, [0, 2, 3, 1])
180
+ nchw_to_nhwc = create_node(
181
+ graph,
182
+ torch.ops.aten.permute.default,
183
+ args=(input_tensor, [0, 2, 3, 1]),
184
+ origin=input_tensor,
176
185
  )
177
- resize_nearest_neighbor = graph.call_function(
186
+ resize_nearest_neighbor = create_node(
187
+ graph,
178
188
  torch.ops.circle_custom.resize_nearest_neighbor,
179
189
  args=(nchw_to_nhwc, [expected_H, expected_W]),
190
+ origin=node,
180
191
  )
181
- nhwc_to_nchw = graph.call_function(
192
+ nhwc_to_nchw = create_node(
193
+ graph,
182
194
  torch.ops.aten.permute.default,
183
195
  args=(resize_nearest_neighbor, [0, 3, 1, 2]),
184
196
  )
185
- # Not set meta for propagating replacing node's meta.
186
197
  node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
187
198
  return resize_nearest_neighbor
188
199
 
@@ -30,7 +30,7 @@ from torch.export import ExportedProgram
30
30
  from tico.passes import ops
31
31
  from tico.serialize.circle_graph import extract_shape
32
32
  from tico.utils import logging
33
- from tico.utils.graph import is_single_value_tensor
33
+ from tico.utils.graph import create_node, is_single_value_tensor
34
34
  from tico.utils.passes import PassBase, PassResult
35
35
  from tico.utils.trace_decorators import trace_const_diff_on_pass
36
36
  from tico.utils.utils import is_target_node
@@ -103,17 +103,22 @@ class LowerSelectCopyToSlice(PassBase):
103
103
 
104
104
  with graph.inserting_after(node):
105
105
  # slice
106
- slice_node = graph.call_function(
107
- torch.ops.aten.slice.Tensor, args=slice_copy_args
106
+ slice_node = create_node(
107
+ graph,
108
+ torch.ops.aten.slice.Tensor,
109
+ args=slice_copy_args,
110
+ origin=node,
108
111
  )
109
112
  node_shape = extract_shape(node)
110
113
  with graph.inserting_after(slice_node):
111
114
  # reshape
112
115
  reshape_args = (slice_node, list(node_shape))
113
- reshape_node = graph.call_function(
114
- torch.ops.aten.reshape.default, args=reshape_args
116
+ reshape_node = create_node(
117
+ graph,
118
+ torch.ops.aten.reshape.default,
119
+ args=reshape_args,
115
120
  )
116
- node.replace_all_uses_with(reshape_node, propagate_meta=False)
121
+ node.replace_all_uses_with(reshape_node, propagate_meta=True)
117
122
 
118
123
  modified = True
119
124
  logger.debug(
@@ -196,17 +201,22 @@ class LowerIndexSelectToSlice(PassBase):
196
201
 
197
202
  with graph.inserting_after(node):
198
203
  # slice
199
- slice_node = graph.call_function(
200
- torch.ops.aten.slice.Tensor, args=slice_copy_args
204
+ slice_node = create_node(
205
+ graph,
206
+ torch.ops.aten.slice.Tensor,
207
+ args=slice_copy_args,
208
+ origin=node,
201
209
  )
202
210
  node_shape = extract_shape(node)
203
211
  with graph.inserting_after(slice_node):
204
212
  # reshape
205
213
  reshape_args = (slice_node, list(node_shape))
206
- reshape_node = graph.call_function(
207
- torch.ops.aten.reshape.default, args=reshape_args
214
+ reshape_node = create_node(
215
+ graph,
216
+ torch.ops.aten.reshape.default,
217
+ args=reshape_args,
208
218
  )
209
- node.replace_all_uses_with(reshape_node, propagate_meta=False)
219
+ node.replace_all_uses_with(reshape_node, propagate_meta=True)
210
220
 
211
221
  modified = True
212
222
  logger.debug(
@@ -14,7 +14,6 @@
14
14
 
15
15
  from typing import TYPE_CHECKING
16
16
 
17
-
18
17
  if TYPE_CHECKING:
19
18
  import torch.fx
20
19
  import torch
@@ -23,6 +22,7 @@ from torch.export import ExportedProgram
23
22
  from tico.passes import ops
24
23
  from tico.serialize.circle_mapping import extract_shape
25
24
  from tico.utils import logging
25
+ from tico.utils.graph import create_node
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
28
  from tico.utils.utils import is_target_node
@@ -106,8 +106,10 @@ class RemoveRedundantPermutePattern1(PassBase):
106
106
  else:
107
107
  with graph.inserting_after(permute2):
108
108
  new_args = (permute1_input, fused_dims)
109
- fused_permute = graph.call_function(
110
- torch.ops.aten.permute.default, args=new_args
109
+ fused_permute = create_node(
110
+ graph,
111
+ torch.ops.aten.permute.default,
112
+ args=new_args,
111
113
  )
112
114
  permute2.replace_all_uses_with(fused_permute, propagate_meta=True)
113
115
  logger.debug(f"{permute1.name} and {permute2.name} are fused.")
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
22
22
  from tico.passes import ops
23
23
  from tico.serialize.circle_mapping import extract_shape
24
24
  from tico.utils import logging
25
+ from tico.utils.graph import create_node
25
26
  from tico.utils.passes import PassBase, PassResult
26
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
28
  from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
@@ -369,8 +370,10 @@ class RemoveRedundantReshapePattern4(PassBase):
369
370
  assert isinstance(s, int), type(s)
370
371
 
371
372
  with graph.inserting_before(reshape1):
372
- fused_reshape = graph.call_function(
373
- reshape1.target, (reshape1_input, reshape2_size)
373
+ fused_reshape = create_node(
374
+ graph,
375
+ reshape1.target,
376
+ (reshape1_input, reshape2_size),
374
377
  )
375
378
 
376
379
  reshape2.replace_all_uses_with(fused_reshape, propagate_meta=True)
@@ -70,6 +70,10 @@ class RemoveRedundantToCopy(PassBase):
70
70
  if input_dtype != target_dtype:
71
71
  continue
72
72
 
73
+ if hasattr(args, "memory_format") and args.memory_format is not None:
74
+ if args.memory_format != torch.contiguous_format:
75
+ continue
76
+
73
77
  node.replace_all_uses_with(input_, propagate_meta=False)
74
78
 
75
79
  modified = True
@@ -16,6 +16,7 @@ import torch
16
16
  from torch.export import ExportedProgram
17
17
 
18
18
  from tico.utils import logging
19
+ from tico.utils.graph import create_node
19
20
  from tico.utils.passes import PassBase, PassResult
20
21
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
21
22
 
@@ -74,11 +75,12 @@ class RestoreLinear(PassBase):
74
75
 
75
76
  addmm_args = (input, weight, bias)
76
77
  with graph.inserting_after(node):
77
- linear_node = graph.call_function(
78
- torch.ops.aten.linear.default, args=addmm_args
78
+ linear_node = create_node(
79
+ graph,
80
+ torch.ops.aten.linear.default,
81
+ args=addmm_args,
79
82
  )
80
83
  node.replace_all_uses_with(linear_node, propagate_meta=True)
81
- graph.erase_node(node)
82
84
 
83
85
  elif node.target == torch.ops.aten.mm.default:
84
86
  assert len(node.args) == 2
@@ -97,8 +99,8 @@ class RestoreLinear(PassBase):
97
99
 
98
100
  mm_args = (input, weight)
99
101
  with graph.inserting_after(node):
100
- linear_node = graph.call_function(
101
- torch.ops.aten.linear.default, args=mm_args
102
+ linear_node = create_node(
103
+ graph, torch.ops.aten.linear.default, args=mm_args
102
104
  )
103
105
  node.replace_all_uses_with(linear_node, propagate_meta=True)
104
106
 
@@ -31,7 +31,7 @@ from torch.export import ExportedProgram
31
31
  from tico.passes import ops
32
32
  from tico.serialize.circle_graph import extract_shape
33
33
  from tico.utils import logging
34
- from tico.utils.graph import add_placeholder, is_single_value_tensor
34
+ from tico.utils.graph import add_placeholder, create_node, is_single_value_tensor
35
35
  from tico.utils.passes import PassBase, PassResult
36
36
  from tico.utils.trace_decorators import trace_const_diff_on_pass
37
37
  from tico.utils.utils import is_target_node
@@ -116,18 +116,22 @@ class SegmentIndexSelectConst(PassBase):
116
116
  exported_program, torch.tensor([i]), prefix="segm_index"
117
117
  )
118
118
  with graph.inserting_before(node):
119
- index_select_node = graph.call_function(
119
+ index_select_node = create_node(
120
+ graph,
120
121
  torch.ops.aten.index_select.default,
121
122
  args=(input, dim, index_node),
123
+ origin=node,
122
124
  )
123
125
  index_select_node_list.append(index_select_node)
124
126
 
125
127
  with graph.inserting_before(node):
126
- concat_node = graph.call_function(
127
- torch.ops.aten.cat.default, args=(index_select_node_list, dim)
128
+ concat_node = create_node(
129
+ graph,
130
+ torch.ops.aten.cat.default,
131
+ args=(index_select_node_list, dim),
128
132
  )
129
133
 
130
- node.replace_all_uses_with(concat_node, propagate_meta=False)
134
+ node.replace_all_uses_with(concat_node, propagate_meta=True)
131
135
 
132
136
  modified = True
133
137
  logger.debug(
tico/utils/graph.py CHANGED
@@ -16,7 +16,7 @@
16
16
  # See the License for the specific language governing permissions and
17
17
  # limitations under the License.
18
18
 
19
- from typing import Optional, TYPE_CHECKING
19
+ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
20
20
 
21
21
  if TYPE_CHECKING:
22
22
  import torch.fx
@@ -24,7 +24,7 @@ import torch
24
24
  from torch.export import ExportedProgram
25
25
  from torch.export.exported_program import InputKind, InputSpec, TensorArgument
26
26
 
27
- from tico.utils.utils import get_fake_mode
27
+ from tico.utils.utils import get_fake_mode, set_new_meta_val
28
28
 
29
29
 
30
30
  def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
@@ -234,3 +234,49 @@ def get_module_name_chain(node: Optional[torch.fx.Node]) -> str:
234
234
  return next(reversed(stack.values()))[1]
235
235
  else:
236
236
  return "unknown"
237
+
238
+
239
+ def create_node(
240
+ graph: torch.fx.Graph,
241
+ target: torch._ops.OpOverload,
242
+ args: Optional[Tuple[Any, ...]] = None,
243
+ kwargs: Optional[Dict[str, Any]] = None,
244
+ *,
245
+ origin: Optional[torch.fx.Node] = None,
246
+ ) -> torch.fx.Node:
247
+ """
248
+ Insert a new node into graph and propagate metadata from *origin*.
249
+
250
+ Parameters
251
+ ----------
252
+ graph : torch.fx.Graph
253
+ The graph that will own the newly-created node.
254
+
255
+ target : torch._ops.OpOverload
256
+ The op overload invoked by the new "call_function" node (e.g. `torch.ops.aten.permute.default`).
257
+
258
+ args : Tuple[Any, ...], optional
259
+ Positional arguments for the new node.
260
+
261
+ kwargs : Dict[str, Any], optional
262
+ Keyword arguments for the new node.
263
+
264
+ origin : torch.fx.Node, optional
265
+ If given, only the "nn_module_stack" entry of `origin.meta` (when
266
+ present) is copied onto the new node, so that it keeps the originating
267
+ module context. No other meta key, including "val", is set here.
268
+
269
+ Returns
270
+ -------
271
+ torch.fx.Node
272
+ The freshly inserted node. Apart from the optional "nn_module_stack" copy, this function does not populate `.meta`.
273
+ """
274
+ new_node = graph.call_function(target, args=args, kwargs=kwargs)
275
+ if origin:
276
+ assert isinstance(origin, torch.fx.Node), type(origin)
277
+ # Propagate "nn_module_stack" to retain the originating module context
278
+ # for meaningful node names.
279
+ if "nn_module_stack" in origin.meta:
280
+ new_node.meta["nn_module_stack"] = origin.meta["nn_module_stack"]
281
+
282
+ return new_node
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250615
3
+ Version: 0.1.0.dev250617
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -105,10 +105,15 @@ You can convert a torch module to a circle model with these steps.
105
105
  torch_module = AddModule()
106
106
  example_inputs = (torch.ones(4), torch.ones(4))
107
107
 
108
- circle_model = tico.convert(torch_module, example_inputs)
108
+ circle_model = tico.convert(torch_module.eval(), example_inputs)
109
109
  circle_model.save('add.circle')
110
110
  ```
111
111
 
112
+ **NOTE**
113
+ Please make sure to call `eval()` on the PyTorch module before passing it to our API.
114
+ This puts the model in evaluation mode: dropout is disabled and batch
115
+ normalization uses its running statistics instead of updating them.
116
+
112
117
  **Compile with configuration**
113
118
 
114
119
  ```python