tico 0.1.0.dev250615__py3-none-any.whl → 0.1.0.dev250617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. tico/__init__.py +1 -1
  2. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +6 -2
  3. tico/passes/cast_aten_where_arg_type.py +4 -1
  4. tico/passes/cast_mixed_type_args.py +4 -1
  5. tico/passes/convert_conv1d_to_conv2d.py +12 -4
  6. tico/passes/convert_layout_op_to_reshape.py +3 -2
  7. tico/passes/convert_repeat_to_expand_copy.py +5 -2
  8. tico/passes/convert_to_relu6.py +4 -3
  9. tico/passes/decompose_addmm.py +11 -7
  10. tico/passes/decompose_batch_norm.py +7 -11
  11. tico/passes/decompose_fake_quantize.py +12 -6
  12. tico/passes/decompose_fake_quantize_tensor_qparams.py +12 -6
  13. tico/passes/decompose_group_norm.py +50 -21
  14. tico/passes/decompose_grouped_conv2d.py +15 -7
  15. tico/passes/decompose_slice_scatter.py +9 -5
  16. tico/passes/fuse_leading_unsqueeze_reshape.py +8 -3
  17. tico/passes/legalize_predefined_layout_operators.py +33 -25
  18. tico/passes/lower_pow2_to_mul.py +3 -1
  19. tico/passes/lower_to_resize_nearest_neighbor.py +21 -10
  20. tico/passes/lower_to_slice.py +21 -11
  21. tico/passes/remove_redundant_permute.py +5 -3
  22. tico/passes/remove_redundant_reshape.py +5 -2
  23. tico/passes/remove_redundant_to_copy.py +4 -0
  24. tico/passes/restore_linear.py +7 -5
  25. tico/passes/segment_index_select.py +9 -5
  26. tico/utils/graph.py +48 -2
  27. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/METADATA +7 -2
  28. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/RECORD +32 -32
  29. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/LICENSE +0 -0
  30. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/WHEEL +0 -0
  31. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/entry_points.txt +0 -0
  32. {tico-0.1.0.dev250615.dist-info → tico-0.1.0.dev250617.dist-info}/top_level.txt +0 -0
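The hunks below all apply the same refactor: direct `graph.call_function(...)` calls inside the passes are replaced with a `create_node(...)` helper imported from `tico.utils.graph` (the module listed above with +48 -2), which takes the graph, the target op, args/kwargs, and an optional `origin` node. The helper's implementation is not included in this diff, so the following is only a minimal sketch of what such a wrapper plausibly looks like, inferred from the call sites; the metadata keys it copies are assumptions, not the actual tico code.

# Hypothetical sketch only -- the real create_node lives in tico/utils/graph.py
# and is not shown in this diff. The signature below is inferred from the call
# sites in the passes; the metadata handling is an assumption.
from typing import Any, Dict, Optional, Tuple

import torch


def create_node(
    graph: torch.fx.Graph,
    target,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    origin: Optional[torch.fx.Node] = None,
) -> torch.fx.Node:
    # Create the call_function node exactly as graph.call_function would.
    new_node = graph.call_function(target, args, kwargs or {})
    if origin is not None:
        # Assumed behavior: carry bookkeeping metadata over from the node being
        # replaced so the new node stays traceable to its source op. meta["val"]
        # is left alone; the passes still recompute it via set_new_meta_val.
        for key in ("stack_trace", "nn_module_stack", "source_fn_stack"):
            if key in origin.meta:
                new_node.meta.setdefault(key, origin.meta[key])
    return new_node

At the call sites this turns, for example, `matmul = graph.call_function(torch.ops.aten.mm.default, (mat1, mat2))` into `matmul = create_node(graph, torch.ops.aten.mm.default, (mat1, mat2), origin=node)`; where no `origin` is passed, the helper presumably behaves like a plain `call_function`.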
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2

  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
- __version__ = "0.1.0.dev250615"
+ __version__ = "0.1.0.dev250617"

  MINIMUM_SUPPORTED_VERSION = "2.5.0"
  SECURE_TORCH_VERSION = "2.6.0"

tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py CHANGED
@@ -24,6 +24,7 @@ from torch.export import ExportedProgram
  from tico.serialize.quant_param import QPARAM_KEY, QuantParam
  from tico.utils import logging
  from tico.utils.errors import NotYetSupportedError
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import quant_min_max, set_new_meta_val
@@ -145,9 +146,11 @@ class InsertQuantizeOnDtypeMismatch(PassBase):

  with graph.inserting_before(node):
  q_args = (inp, scale, zerop, min_, max_, dtype)
- quantize = graph.call_function(
+ quantize = create_node(
+ graph,
  torch.ops.quantized_decomposed.quantize_per_tensor.default,
  args=q_args,
+ origin=node,
  )
  quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
  set_new_meta_val(quantize)
@@ -166,7 +169,8 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
  dtype = getattr(torch, qparam.dtype)
  with graph.inserting_after(node):
  q_args = (node, scale, zerop, min_, max_, dtype)
- quantize = graph.call_function(
+ quantize = create_node(
+ graph,
  torch.ops.quantized_decomposed.quantize_per_tensor.default,
  args=q_args,
  )

tico/passes/cast_aten_where_arg_type.py CHANGED
@@ -21,6 +21,7 @@ from torch.export import ExportedProgram

  from tico.serialize.circle_mapping import extract_torch_dtype
  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import (
  trace_const_diff_on_pass,
@@ -158,10 +159,12 @@ class CastATenWhereArgType(PassBase):
  f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
  )
  with graph_module.graph.inserting_after(to_cast):
- cast = graph_module.graph.call_function(
+ cast = create_node(
+ graph,
  torch.ops.aten._to_copy.default,
  args=(to_cast,),
  kwargs={"dtype": dtype_to_cast},
+ origin=to_cast,
  )
  # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
  set_new_meta_val(cast)

tico/passes/cast_mixed_type_args.py CHANGED
@@ -26,6 +26,7 @@ from torch.export import ExportedProgram

  from tico.serialize.circle_mapping import extract_torch_dtype
  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import is_target_node, set_new_meta_val
@@ -126,10 +127,12 @@ class CastMixedTypeArgs(PassBase):

  if isinstance(arg_to_promote, torch.fx.Node):
  with graph.inserting_after(arg_to_promote):
- to_copy = graph.call_function(
+ to_copy = create_node(
+ graph,
  torch.ops.aten._to_copy.default,
  (arg_to_promote,),
  {"dtype": type_to_promote},
+ origin=arg_to_promote,
  )
  # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
  set_new_meta_val(to_copy)

tico/passes/convert_conv1d_to_conv2d.py CHANGED
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
  from tico.serialize.circle_graph import extract_shape
  from tico.utils import logging
  from tico.utils.errors import NotYetSupportedError
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import is_target_node
@@ -89,15 +90,19 @@ class ConvertConv1dToConv2d(PassBase):
  )

  with graph.inserting_after(input):
- input_unsqueeze = graph_module.graph.call_function(
+ input_unsqueeze = create_node(
+ graph,
  torch.ops.aten.unsqueeze.default,
  args=(input, 3),
+ origin=input,
  )

  with graph.inserting_after(weight):
- weight_unsqueeze = graph_module.graph.call_function(
+ weight_unsqueeze = create_node(
+ graph,
  torch.ops.aten.unsqueeze.default,
  args=(weight, 3),
+ origin=weight,
  )

  with graph.inserting_before(node):
@@ -106,7 +111,8 @@ class ConvertConv1dToConv2d(PassBase):
  elif isinstance(padding, str):
  conv2d_op = torch.ops.aten.conv2d.padding

- conv2d = graph_module.graph.call_function(
+ conv2d = create_node(
+ graph,
  conv2d_op,
  args=(
  input_unsqueeze,
@@ -118,9 +124,11 @@ class ConvertConv1dToConv2d(PassBase):
  groups,
  ),
  kwargs=node.kwargs,
+ origin=node,
  )

- conv_out_squeeze = graph_module.graph.call_function(
+ conv_out_squeeze = create_node(
+ graph,
  torch.ops.aten.squeeze.dims,
  args=(conv2d, [3]),
  )

tico/passes/convert_layout_op_to_reshape.py CHANGED
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
  from tico.passes import ops
  from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.validate_args_kwargs import SqueezeArgs, UnSqueezeArgs, ViewArgs
@@ -48,11 +49,11 @@ class ConvertLayoutOpToReshape(PassBase):
  out_shape = list(extract_shape(node))

  with graph.inserting_after(node):
- reshape_node = graph.call_function(
+ reshape_node = create_node(
+ graph,
  torch.ops.aten.reshape.default,
  args=(input, out_shape),
  )
-
  node.replace_all_uses_with(reshape_node, propagate_meta=True)

  logger.debug(f"{node.name} is replaced with {reshape_node.name}")

tico/passes/convert_repeat_to_expand_copy.py CHANGED
@@ -20,6 +20,7 @@ import torch
  from torch.export import ExportedProgram

  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import is_target_node
@@ -71,8 +72,10 @@ class ConvertRepeatToExpandCopy(PassBase):
  expand_copy_args = (tensor, size)

  with graph.inserting_after(node):
- expand_copy_node = graph.call_function(
- torch.ops.aten.expand_copy.default, args=expand_copy_args
+ expand_copy_node = create_node(
+ graph,
+ torch.ops.aten.expand_copy.default,
+ args=expand_copy_args,
  )
  node.replace_all_uses_with(expand_copy_node, propagate_meta=True)

tico/passes/convert_to_relu6.py CHANGED
@@ -20,6 +20,7 @@ import torch
  from torch.export import ExportedProgram

  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.validate_args_kwargs import ClampArgs, HardTanhArgs
@@ -58,7 +59,7 @@ class ConvertHardTanhToReLU6(Converter):
  input = args.input

  with graph.inserting_after(node):
- relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
+ relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
  node.replace_all_uses_with(relu_node, propagate_meta=True)


@@ -84,7 +85,7 @@ class ConvertClampToReLU6(Converter):
  input = args.input

  with graph.inserting_after(node):
- relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
+ relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
  node.replace_all_uses_with(relu_node, propagate_meta=True)


@@ -140,7 +141,7 @@ class ConvertDoubleClampsToReLU6(Converter):
  input = prev_args.input

  with graph.inserting_after(node):
- relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
+ relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
  node.replace_all_uses_with(relu_node, propagate_meta=True)

tico/passes/decompose_addmm.py CHANGED
@@ -21,7 +21,7 @@ from torch.export import ExportedProgram

  from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
- from tico.utils.graph import add_placeholder
+ from tico.utils.graph import add_placeholder, create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import is_target_node, set_new_meta_val
@@ -78,7 +78,9 @@ class DecomposeAddmm(PassBase):

  with graph.inserting_before(node):
  # out = beta * input + alpha * (mat1 @ mat2)
- matmul = graph.call_function(torch.ops.aten.mm.default, (mat1, mat2))
+ matmul = create_node(
+ graph, torch.ops.aten.mm.default, (mat1, mat2), origin=node
+ )
  set_new_meta_val(matmul)

  if beta == 1:
@@ -90,7 +92,9 @@ class DecomposeAddmm(PassBase):
  f"{node.name}_beta_zeros",
  )
  else:
- bias = graph.call_function(torch.ops.aten.mul.Tensor, (input, beta))
+ bias = create_node(
+ graph, torch.ops.aten.mul.Tensor, (input, beta), origin=node
+ )

  if alpha == 1:
  scaled_matmul: torch.fx.Node | torch.Tensor = matmul
@@ -101,12 +105,12 @@ class DecomposeAddmm(PassBase):
  f"{node.name}_alpha_zeros",
  )
  else:
- scaled_matmul = graph.call_function(
- torch.ops.aten.mul.Tensor, (matmul, alpha)
+ scaled_matmul = create_node(
+ graph, torch.ops.aten.mul.Tensor, (matmul, alpha), origin=node
  )

- result = graph.call_function(
- torch.ops.aten.add.Tensor, (bias, scaled_matmul)
+ result = create_node(
+ graph, torch.ops.aten.add.Tensor, (bias, scaled_matmul)
  )

  node.replace_all_uses_with(result, propagate_meta=True)

tico/passes/decompose_batch_norm.py CHANGED
@@ -24,6 +24,7 @@ from tico.utils import logging
  from tico.utils.errors import NotYetSupportedError
  from tico.utils.graph import (
  add_placeholder,
+ create_node,
  get_first_user_input,
  get_torch_buffer_value,
  get_torch_param_value,
@@ -32,16 +33,10 @@ from tico.utils.graph (
  )
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
- from tico.utils.utils import fill_meta_val, is_target_node
+ from tico.utils.utils import is_target_node
  from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs


- def insert_node(graph: torch.fx.Graph, operation, args):
- new_node = graph.call_function(operation, args)
-
- return new_node
-
-
  @trace_graph_diff_on_pass
  class DecomposeBatchNorm(PassBase):
  """
@@ -173,19 +168,20 @@ class DecomposeBatchNorm(PassBase):

  )
  with gm.graph.inserting_before(node):
- mul = graph.call_function(
+ mul = create_node(
+ graph,
  torch.ops.aten.mul.Tensor,
  args=(input_, mul_const_node),
+ origin=node,
  )
- add = graph.call_function(
+ add = create_node(
+ graph,
  torch.ops.aten.add.Tensor,
  args=(mul, add_const_node),
  )
- # Not set meta for propagating replacing get_item's meta.
  get_item, *_ = node.users.keys()
  get_item.replace_all_uses_with(add, propagate_meta=True)

- fill_meta_val(exported_program)
  logger.debug(f"{node.name} is decomposed to {mul.name} and {add.name}")
  modified = True

tico/passes/decompose_fake_quantize.py CHANGED
@@ -23,6 +23,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
  from torch.export import ExportedProgram

  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.validate_args_kwargs import FakeQuantizePerChannelArgs
@@ -69,6 +70,7 @@ class DecomposeFakeQuantize(PassBase):
  modified = False

  gm = exported_program.graph_module
+ g = gm.graph
  qd = torch.ops.quantized_decomposed # type: ignore[return]
  for node in gm.graph.nodes:
  if node.op != "call_function":
@@ -83,17 +85,19 @@ class DecomposeFakeQuantize(PassBase):
  **{"dtype": get_quant_type(quant_min, quant_max)},
  }
  with gm.graph.inserting_before(node):
- quant = gm.graph.call_function(
+ quant = create_node(
+ g,
  qd.quantize_per_tensor.default,
  args=node.args,
  kwargs=quant_kwargs,
+ origin=node,
  )
- dequnt = gm.graph.call_function(
+ dequnt = create_node(
+ g,
  qd.dequantize_per_tensor.default,
  args=(quant, *quant.args[1:]),
  kwargs=quant.kwargs,
  )
- # Not set meta for propagating replacing node's meta.
  node.replace_all_uses_with(dequnt, propagate_meta=True)
  modified = True

@@ -107,17 +111,19 @@ class DecomposeFakeQuantize(PassBase):
  **{"dtype": get_quant_type(quant_min, quant_max)},
  }
  with gm.graph.inserting_before(node):
- quant = gm.graph.call_function(
+ quant = create_node(
+ g,
  qd.quantize_per_channel.default,
  args=node.args,
  kwargs=quant_kwargs,
+ origin=node,
  )
- dequnt = gm.graph.call_function(
+ dequnt = create_node(
+ g,
  qd.dequantize_per_channel.default,
  args=(quant, *quant.args[1:]),
  kwargs=quant.kwargs,
  )
- # Not set meta for propagating replacing node's meta.
  node.replace_all_uses_with(dequnt, propagate_meta=True)
  modified = True

tico/passes/decompose_fake_quantize_tensor_qparams.py CHANGED
@@ -30,6 +30,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
  from torch.export import ExportedProgram

  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import (
  trace_const_diff_on_pass,
@@ -200,6 +201,7 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
  modified = False

  gm = exported_program.graph_module
+ g = gm.graph
  qd = torch.ops.quantized_decomposed # type: ignore[return]
  for node in gm.graph.nodes:
  if node.op != "call_function":
@@ -226,17 +228,19 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
  **{"dtype": get_quant_type(quant_min, quant_max)},
  }
  with gm.graph.inserting_before(node):
- quant = gm.graph.call_function(
+ quant = create_node(
+ g,
  qd.quantize_per_tensor.default,
  args=(tensor, s_value, zp_value, quant_min, quant_max),
  kwargs=quant_kwargs,
+ origin=node,
  )
- dequant = gm.graph.call_function(
+ dequant = create_node(
+ g,
  qd.dequantize_per_tensor.default,
  args=(quant, *quant.args[1:]),
  kwargs=quant.kwargs,
  )
- # Not set meta for propagating replacing get_item's meta.
  get_item.replace_all_uses_with(dequant, propagate_meta=True)
  # If `mask` can be graph output, which prevents `eliminate_dead_code()` from eliminating `mask`.
  # So, let's remove `mask` from the output.args first.
@@ -267,17 +271,19 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
  **{"dtype": get_quant_type(quant_min, quant_max)},
  }
  with gm.graph.inserting_before(node):
- quant = gm.graph.call_function(
+ quant = create_node(
+ g,
  qd.quantize_per_tensor.default,
  args=(tensor, s_value, zp_value, quant_min, quant_max),
  kwargs=quant_kwargs,
+ origin=node,
  )
- dequant = gm.graph.call_function(
+ dequant = create_node(
+ g,
  qd.dequantize_per_tensor.default,
  args=(quant, *quant.args[1:]),
  kwargs=quant.kwargs,
  )
- # Not set meta for propagating replacing get_item's meta.
  node.replace_all_uses_with(dequant, propagate_meta=True)
  modified = True

tico/passes/decompose_group_norm.py CHANGED
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram

  from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
+ from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import is_target_node
@@ -89,24 +90,40 @@ class DecomposeGroupNorm(PassBase):
  def __init__(self):
  super().__init__()

- def _insert_norm(self, graph, tensor, eps):
+ def _insert_norm(self, graph, tensor, eps, origin):
  """
  Insert (tensor - mean) / sqrt(var + eps)) into the graph
  and return the normalized tensor node.
  """
- mean = graph.call_function(
- torch.ops.aten.mean.dim, (tensor, [-1]), {"keepdim": True}
+ mean = create_node(
+ graph,
+ torch.ops.aten.mean.dim,
+ (tensor, [-1]),
+ {"keepdim": True},
+ origin=origin,
  )
- deviation = graph.call_function(torch.ops.aten.sub.Tensor, (tensor, mean))
- squared = graph.call_function(torch.ops.aten.pow.Tensor_Scalar, (deviation, 2))
- var = graph.call_function(
- torch.ops.aten.mean.dim, (squared, [-1]), {"keepdim": True}
+ deviation = create_node(
+ graph, torch.ops.aten.sub.Tensor, (tensor, mean), origin=origin
  )
- inverse_std = graph.call_function(
+ squared = create_node(
+ graph, torch.ops.aten.pow.Tensor_Scalar, (deviation, 2), origin=origin
+ )
+ var = create_node(
+ graph,
+ torch.ops.aten.mean.dim,
+ (squared, [-1]),
+ {"keepdim": True},
+ origin=origin,
+ )
+ inverse_std = create_node(
+ graph,
  torch.ops.aten.rsqrt.default,
- (graph.call_function(torch.ops.aten.add.Tensor, (var, eps)),),
+ (create_node(graph, torch.ops.aten.add.Tensor, (var, eps), origin=origin),),
+ origin=origin,
+ )
+ return create_node(
+ graph, torch.ops.aten.mul.Tensor, (deviation, inverse_std), origin=origin
  )
- return graph.call_function(torch.ops.aten.mul.Tensor, (deviation, inverse_std))

  def call(self, exported_program: ExportedProgram) -> PassResult:
  logger = logging.getLogger(__name__)
@@ -178,17 +195,23 @@
  # Branch only on whether a reshape is needed; the normalization is shared.
  if norm_size != x_shape[-1]:
  # Pack groups so that the last dimension equals norm_size.
- packed = graph.call_function(
- torch.ops.aten.reshape.default, (x, pack_shape)
+ packed = create_node(
+ graph,
+ torch.ops.aten.reshape.default,
+ (x, pack_shape),
+ origin=node,
  )
- normed = self._insert_norm(graph, packed, eps)
+ normed = self._insert_norm(graph, packed, eps, origin=node)
  # Restore the original shape after normalization.
- layer_norm = graph.call_function(
- torch.ops.aten.reshape.default, (normed, x_shape)
+ layer_norm = create_node(
+ graph,
+ torch.ops.aten.reshape.default,
+ (normed, x_shape),
+ origin=node,
  )
  else:
  # The input already has norm_size in the last dimension.
- layer_norm = self._insert_norm(graph, x, eps)
+ layer_norm = self._insert_norm(graph, x, eps, origin=node)

  # weight
  if weight:
@@ -197,13 +220,17 @@
  assert weight_shape[0] == C
  reshape_size = [1] * len(x_shape)
  reshape_size[1] = C
- weight = graph.call_function(
+ weight = create_node(
+ graph,
  torch.ops.aten.view.default,
  (weight, reshape_size),
+ origin=node,
  )
- layer_norm = graph.call_function(
+ layer_norm = create_node(
+ graph,
  torch.ops.aten.mul.Tensor,
  (layer_norm, weight),
+ origin=node,
  )

  # bias
@@ -213,15 +240,17 @@
  assert bias_shape[0] == C
  reshape_size = [1] * len(x_shape)
  reshape_size[1] = C
- bias = graph.call_function(
+ bias = create_node(
+ graph,
  torch.ops.aten.view.default,
  (bias, reshape_size),
+ origin=node,
  )
- layer_norm = graph.call_function(
+ layer_norm = create_node(
+ graph,
  torch.ops.aten.add.Tensor,
  (layer_norm, bias),
  )
-
  # Reset last node's meta for propagating replacing node's meta.
  layer_norm.meta = {}

tico/passes/decompose_grouped_conv2d.py CHANGED
@@ -23,7 +23,7 @@ from tico.passes import ops
  from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
  from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
- from tico.utils.graph import add_placeholder
+ from tico.utils.graph import add_placeholder, create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import is_target_node
@@ -159,19 +159,26 @@

  conv2d_tensors = []
  for i in range(groups):
- sliced_input = graph.call_function(
+ sliced_input = create_node(
+ graph,
  torch.ops.aten.slice.Tensor,
  (input_, 1, group_size * i, group_size * (i + 1), 1),
+ origin=node,
  )
- sliced_weight = graph.call_function(
+ sliced_weight = create_node(
+ graph,
  torch.ops.aten.slice.Tensor,
  (weight, 0, out_group_size * i, out_group_size * (i + 1), 1),
+ origin=node,
  )
- sliced_bias = graph.call_function(
+ sliced_bias = create_node(
+ graph,
  torch.ops.aten.slice.Tensor,
  (bias, 0, out_group_size * i, out_group_size * (i + 1), 1),
+ origin=node,
  )
- conv2d_tensor = graph.call_function(
+ conv2d_tensor = create_node(
+ graph,
  conv2d_op,
  (
  sliced_input,
@@ -182,11 +189,12 @@
  dilation,
  1,
  ),
+ origin=node,
  )
  conv2d_tensors.append(conv2d_tensor)

- concat_output = graph.call_function(
- torch.ops.aten.cat.default, (conv2d_tensors, 1)
+ concat_output = create_node(
+ graph, torch.ops.aten.cat.default, (conv2d_tensors, 1)
  )

  node.replace_all_uses_with(concat_output, propagate_meta=True)