tico 0.1.0.dev250609__py3-none-any.whl → 0.1.0.dev250611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. tico/__init__.py +1 -1
  2. tico/passes/cast_aten_where_arg_type.py +65 -62
  3. tico/passes/cast_mixed_type_args.py +2 -5
  4. tico/passes/convert_conv1d_to_conv2d.py +3 -4
  5. tico/passes/convert_repeat_to_expand_copy.py +5 -9
  6. tico/passes/decompose_addmm.py +41 -48
  7. tico/passes/decompose_batch_norm.py +97 -99
  8. tico/passes/decompose_fake_quantize.py +4 -6
  9. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -4
  10. tico/passes/decompose_group_norm.py +8 -7
  11. tico/passes/decompose_grouped_conv2d.py +2 -3
  12. tico/passes/decompose_slice_scatter.py +2 -4
  13. tico/passes/extract_dtype_kwargs.py +2 -1
  14. tico/passes/fuse_leading_unsqueeze_reshape.py +107 -0
  15. tico/passes/fuse_redundant_reshape_to_mean.py +3 -5
  16. tico/passes/legalize_causal_mask_value.py +2 -7
  17. tico/passes/legalize_predefined_layout_operators.py +2 -3
  18. tico/passes/lower_pow2_to_mul.py +5 -7
  19. tico/passes/lower_to_resize_nearest_neighbor.py +6 -10
  20. tico/passes/lower_to_slice.py +3 -9
  21. tico/passes/merge_consecutive_cat.py +2 -4
  22. tico/passes/remove_nop.py +2 -3
  23. tico/passes/remove_redundant_assert_nodes.py +2 -1
  24. tico/passes/remove_redundant_expand.py +5 -9
  25. tico/passes/remove_redundant_permute.py +6 -5
  26. tico/passes/remove_redundant_reshape.py +17 -34
  27. tico/passes/remove_redundant_slice.py +2 -4
  28. tico/passes/remove_redundant_to_copy.py +2 -4
  29. tico/passes/segment_index_select.py +2 -4
  30. tico/serialize/operators/op_where.py +2 -2
  31. tico/utils/convert.py +2 -0
  32. tico/utils/utils.py +26 -0
  33. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/METADATA +1 -1
  34. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/RECORD +38 -37
  35. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/LICENSE +0 -0
  36. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/WHEEL +0 -0
  37. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/entry_points.txt +0 -0
  38. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/top_level.txt +0 -0
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250609"
24
+ __version__ = "0.1.0.dev250611"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -26,7 +26,8 @@ from tico.utils.trace_decorators import (
26
26
  trace_const_diff_on_pass,
27
27
  trace_graph_diff_on_pass,
28
28
  )
29
- from tico.utils.utils import set_new_meta_val
29
+ from tico.utils.utils import is_target_node, set_new_meta_val
30
+ from tico.utils.validate_args_kwargs import WhereSelfArgs
30
31
 
31
32
 
32
33
  dtype_ranking = {
@@ -114,69 +115,71 @@ class CastATenWhereArgType(PassBase):
114
115
  modified = False
115
116
 
116
117
  for node in graph.nodes:
117
- if node.op == "call_function" and node.target == torch.ops.aten.where.self:
118
-
119
- assert len(node.args) == 3
120
- (
121
- _,
122
- result_true,
123
- result_false,
124
- ) = node.args # first argument is not used
125
-
126
- ep = exported_program
127
-
128
- if not (
129
- result_true.name in ep.graph_signature.inputs_to_buffers
130
- and result_false.name in ep.graph_signature.inputs_to_buffers
131
- ):
132
- continue
133
-
134
- # Check if they have different data types
135
- true_dtype = extract_torch_dtype(result_true)
136
- false_dtype = extract_torch_dtype(result_false)
137
- if true_dtype == false_dtype:
138
- continue
139
-
140
- node_to_dtype = {result_true: true_dtype, result_false: false_dtype}
141
-
142
- not_to_cast, to_cast = sort_by_dtype(result_true, result_false)
143
-
144
- buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
145
- buf_name = ep.graph_signature.inputs_to_buffers[to_cast.name]
146
- buf_data = buf_name_to_data[buf_name]
147
-
148
- assert isinstance(buf_data, torch.Tensor)
149
-
150
- dtype_to_cast = node_to_dtype[not_to_cast]
151
-
152
- if dtype_to_cast == torch.float32:
153
- if not check_if_covered_by_float(buf_data):
154
- raise RuntimeError(
155
- f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
156
- )
157
- with graph_module.graph.inserting_after(to_cast):
158
- cast = graph_module.graph.call_function(
159
- torch.ops.aten._to_copy.default,
160
- args=(to_cast,),
161
- kwargs={"dtype": dtype_to_cast},
118
+ if not is_target_node(node, torch.ops.aten.where.self):
119
+ continue
120
+
121
+ where_args = WhereSelfArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
122
+ result_true, result_false = where_args.input, where_args.other
123
+ if not isinstance(result_true, torch.fx.Node) or not isinstance(
124
+ result_false, torch.fx.Node
125
+ ):
126
+ continue
127
+
128
+ ep = exported_program
129
+ assert isinstance(result_true, torch.fx.Node)
130
+ assert isinstance(result_false, torch.fx.Node)
131
+ if not (
132
+ result_true.name in ep.graph_signature.inputs_to_buffers
133
+ and result_false.name in ep.graph_signature.inputs_to_buffers
134
+ ):
135
+ continue
136
+
137
+ # Check if they have different data types
138
+ true_dtype = extract_torch_dtype(result_true)
139
+ false_dtype = extract_torch_dtype(result_false)
140
+ if true_dtype == false_dtype:
141
+ continue
142
+
143
+ node_to_dtype = {result_true: true_dtype, result_false: false_dtype}
144
+
145
+ not_to_cast, to_cast = sort_by_dtype(result_true, result_false)
146
+
147
+ buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
148
+ buf_name = ep.graph_signature.inputs_to_buffers[to_cast.name]
149
+ buf_data = buf_name_to_data[buf_name]
150
+
151
+ assert isinstance(buf_data, torch.Tensor)
152
+
153
+ dtype_to_cast = node_to_dtype[not_to_cast]
154
+
155
+ if dtype_to_cast == torch.float32:
156
+ if not check_if_covered_by_float(buf_data):
157
+ raise RuntimeError(
158
+ f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
162
159
  )
163
- # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
164
- set_new_meta_val(cast)
165
- node.update_arg(node.args.index(to_cast), cast)
166
-
167
- # check if type promotion is valid.
168
- node_dtype_ori = extract_torch_dtype(node)
169
- set_new_meta_val(node)
170
- node_dtype = extract_torch_dtype(node)
171
- assert (
172
- node_dtype == node_dtype_ori
173
- ), f"Type casting doesn't change node's dtype."
174
-
175
- logger.debug(
176
- f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
160
+ with graph_module.graph.inserting_after(to_cast):
161
+ cast = graph_module.graph.call_function(
162
+ torch.ops.aten._to_copy.default,
163
+ args=(to_cast,),
164
+ kwargs={"dtype": dtype_to_cast},
177
165
  )
178
-
179
- modified = True
166
+ # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
167
+ set_new_meta_val(cast)
168
+ node.update_arg(node.args.index(to_cast), cast)
169
+
170
+ # check if type promotion is valid.
171
+ node_dtype_ori = extract_torch_dtype(node)
172
+ set_new_meta_val(node)
173
+ node_dtype = extract_torch_dtype(node)
174
+ assert (
175
+ node_dtype == node_dtype_ori
176
+ ), f"Type casting doesn't change node's dtype."
177
+
178
+ logger.debug(
179
+ f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
180
+ )
181
+
182
+ modified = True
180
183
 
181
184
  graph.eliminate_dead_code()
182
185
  graph.lint()
@@ -28,7 +28,7 @@ from tico.serialize.circle_mapping import extract_torch_dtype
28
28
  from tico.utils import logging
29
29
  from tico.utils.passes import PassBase, PassResult
30
30
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
31
- from tico.utils.utils import set_new_meta_val
31
+ from tico.utils.utils import is_target_node, set_new_meta_val
32
32
 
33
33
 
34
34
  ops_to_promote = {
@@ -96,10 +96,7 @@ class CastMixedTypeArgs(PassBase):
96
96
  graph = graph_module.graph
97
97
  modified = False
98
98
  for node in graph.nodes:
99
- if not node.op == "call_function":
100
- continue
101
-
102
- if node.target not in ops_to_promote:
99
+ if not is_target_node(node, list(ops_to_promote.keys())):
103
100
  continue
104
101
 
105
102
  assert len(node.args) == 2
@@ -24,6 +24,7 @@ from tico.utils import logging
24
24
  from tico.utils.errors import NotYetSupportedError
25
25
  from tico.utils.passes import PassBase, PassResult
26
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.utils import is_target_node
27
28
  from tico.utils.validate_args_kwargs import Conv1DArgs
28
29
 
29
30
 
@@ -131,17 +132,15 @@ class ConvertConv1dToConv2d(PassBase):
131
132
  return modified
132
133
 
133
134
  def call(self, exported_program: ExportedProgram) -> PassResult:
134
- target_conv_op = (torch.ops.aten.conv1d.default, torch.ops.aten.conv1d.padding)
135
+ target_conv_op = [torch.ops.aten.conv1d.default, torch.ops.aten.conv1d.padding]
135
136
 
136
137
  graph_module = exported_program.graph_module
137
138
  graph = graph_module.graph
138
139
  modified = False
139
140
  for node in graph.nodes:
140
- if not node.op == "call_function":
141
+ if not is_target_node(node, target_conv_op):
141
142
  continue
142
143
 
143
- if node.target not in target_conv_op:
144
- continue
145
144
  modified |= self.convert(exported_program, node)
146
145
 
147
146
  graph.eliminate_dead_code()
@@ -22,6 +22,8 @@ from torch.export import ExportedProgram
22
22
  from tico.utils import logging
23
23
  from tico.utils.passes import PassBase, PassResult
24
24
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_target_node
26
+ from tico.utils.validate_args_kwargs import RepeatArgs
25
27
 
26
28
 
27
29
  @trace_graph_diff_on_pass
@@ -42,17 +44,11 @@ class ConvertRepeatToExpandCopy(PassBase):
42
44
  graph = graph_module.graph
43
45
  modified = False
44
46
  for node in graph.nodes:
45
- if not node.op == "call_function":
47
+ if not is_target_node(node, torch.ops.aten.repeat.default):
46
48
  continue
47
49
 
48
- if node.target != torch.ops.aten.repeat.default:
49
- continue
50
-
51
- assert len(node.args) == 2
52
-
53
- tensor, repeats = node.args
54
- assert isinstance(tensor, torch.fx.Node)
55
- assert isinstance(repeats, list)
50
+ reshape_args = RepeatArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
51
+ tensor, repeats = reshape_args.input, reshape_args.repeats
56
52
 
57
53
  tensor_shape: List[int] = [int(dim) for dim in tensor.meta["val"].shape]
58
54
 
@@ -24,7 +24,7 @@ from tico.utils import logging
24
24
  from tico.utils.graph import add_placeholder
25
25
  from tico.utils.passes import PassBase, PassResult
26
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
- from tico.utils.utils import set_new_meta_val
27
+ from tico.utils.utils import is_target_node, set_new_meta_val
28
28
  from tico.utils.validate_args_kwargs import AddmmArgs
29
29
 
30
30
 
@@ -66,59 +66,52 @@ class DecomposeAddmm(PassBase):
66
66
  modified = False
67
67
 
68
68
  for node in graph.nodes:
69
- if node.op != "call_function":
69
+ if not is_target_node(node, torch.ops.aten.addmm.default):
70
70
  continue
71
71
 
72
- if node.target in [
73
- torch.ops.aten.addmm.default,
74
- ]:
75
- args = AddmmArgs(*node.args, **node.kwargs)
76
- input = args.input
77
- mat1 = args.mat1
78
- mat2 = args.mat2
79
- beta = args.beta
80
- alpha = args.alpha
81
-
82
- with graph.inserting_before(node):
83
- # out = beta * input + alpha * (mat1 @ mat2)
84
- matmul = graph.call_function(
85
- torch.ops.aten.mm.default, (mat1, mat2)
72
+ args = AddmmArgs(*node.args, **node.kwargs)
73
+ input = args.input
74
+ mat1 = args.mat1
75
+ mat2 = args.mat2
76
+ beta = args.beta
77
+ alpha = args.alpha
78
+
79
+ with graph.inserting_before(node):
80
+ # out = beta * input + alpha * (mat1 @ mat2)
81
+ matmul = graph.call_function(torch.ops.aten.mm.default, (mat1, mat2))
82
+ set_new_meta_val(matmul)
83
+
84
+ if beta == 1:
85
+ bias: torch.fx.Node | torch.Tensor = input
86
+ elif beta == 0:
87
+ bias = add_placeholder(
88
+ exported_program,
89
+ torch.zeros(extract_shape(input)),
90
+ f"{node.name}_beta_zeros",
86
91
  )
87
- set_new_meta_val(matmul)
88
-
89
- if beta == 1:
90
- bias: torch.fx.Node | torch.Tensor = input
91
- elif beta == 0:
92
- bias = add_placeholder(
93
- exported_program,
94
- torch.zeros(extract_shape(input)),
95
- f"{node.name}_beta_zeros",
96
- )
97
- else:
98
- bias = graph.call_function(
99
- torch.ops.aten.mul.Tensor, (input, beta)
100
- )
101
-
102
- if alpha == 1:
103
- scaled_matmul: torch.fx.Node | torch.Tensor = matmul
104
- elif alpha == 0:
105
- scaled_matmul = add_placeholder(
106
- exported_program,
107
- torch.zeros(extract_shape(matmul)),
108
- f"{node.name}_alpha_zeros",
109
- )
110
- else:
111
- scaled_matmul = graph.call_function(
112
- torch.ops.aten.mul.Tensor, (matmul, alpha)
113
- )
114
-
115
- result = graph.call_function(
116
- torch.ops.aten.add.Tensor, (bias, scaled_matmul)
92
+ else:
93
+ bias = graph.call_function(torch.ops.aten.mul.Tensor, (input, beta))
94
+
95
+ if alpha == 1:
96
+ scaled_matmul: torch.fx.Node | torch.Tensor = matmul
97
+ elif alpha == 0:
98
+ scaled_matmul = add_placeholder(
99
+ exported_program,
100
+ torch.zeros(extract_shape(matmul)),
101
+ f"{node.name}_alpha_zeros",
117
102
  )
103
+ else:
104
+ scaled_matmul = graph.call_function(
105
+ torch.ops.aten.mul.Tensor, (matmul, alpha)
106
+ )
107
+
108
+ result = graph.call_function(
109
+ torch.ops.aten.add.Tensor, (bias, scaled_matmul)
110
+ )
118
111
 
119
- node.replace_all_uses_with(result, propagate_meta=True)
112
+ node.replace_all_uses_with(result, propagate_meta=True)
120
113
 
121
- modified = True
114
+ modified = True
122
115
 
123
116
  gm.graph.eliminate_dead_code()
124
117
  gm.graph.lint()
@@ -32,7 +32,7 @@ from tico.utils.graph import (
32
32
  )
33
33
  from tico.utils.passes import PassBase, PassResult
34
34
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
35
- from tico.utils.utils import fill_meta_val
35
+ from tico.utils.utils import fill_meta_val, is_target_node
36
36
  from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
37
37
 
38
38
 
@@ -87,109 +87,107 @@ class DecomposeBatchNorm(PassBase):
87
87
  modified = False
88
88
 
89
89
  for node in graph.nodes:
90
- if node.op != "call_function":
90
+ if not is_target_node(
91
+ node, torch.ops.aten._native_batch_norm_legit_no_training.default
92
+ ):
91
93
  continue
92
94
 
93
- if node.target in [
94
- torch.ops.aten._native_batch_norm_legit_no_training.default,
95
- ]:
96
- args = NativeBatchNormLegitNoTrainingArgs(*node.args)
97
- input_ = args.input
98
- weight = args.weight
99
- bias = args.bias
100
- running_mean = args.running_mean
101
- running_var = args.running_var
102
- eps = args.eps
103
-
104
- if not running_mean:
105
- raise NotYetSupportedError(
106
- f"running_mean=None is not supported yet"
107
- )
108
- if not running_var:
109
- raise NotYetSupportedError(f"running_var=None is not supported yet")
110
-
111
- """
112
- Only support the cases generated from torch.nn.BatchNorm2d module,
113
- for which, let's checks if weight and bias are parameters and
114
- running_mean and running_var are buffers.
115
- """
116
- if weight and not is_torch_param(weight, exported_program):
117
- continue
118
- if bias and not is_torch_param(bias, exported_program):
119
- continue
120
- if not is_torch_buffer(running_mean, exported_program):
121
- continue
122
- if not is_torch_buffer(running_var, exported_program):
123
- continue
124
-
125
- input_shape = extract_shape(input_)
126
- assert len(input_shape) == 4
127
- C = input_shape[1]
128
-
129
- weight_value = (
130
- get_torch_param_value(weight, exported_program)
131
- if weight
132
- else torch.tensor([1] * C)
95
+ args = NativeBatchNormLegitNoTrainingArgs(*node.args)
96
+ input_ = args.input
97
+ weight = args.weight
98
+ bias = args.bias
99
+ running_mean = args.running_mean
100
+ running_var = args.running_var
101
+ eps = args.eps
102
+
103
+ if not running_mean:
104
+ raise NotYetSupportedError(f"running_mean=None is not supported yet")
105
+ if not running_var:
106
+ raise NotYetSupportedError(f"running_var=None is not supported yet")
107
+
108
+ """
109
+ Only support the cases generated from torch.nn.BatchNorm2d module,
110
+ for which, let's checks if weight and bias are parameters and
111
+ running_mean and running_var are buffers.
112
+ """
113
+ if weight and not is_torch_param(weight, exported_program):
114
+ continue
115
+ if bias and not is_torch_param(bias, exported_program):
116
+ continue
117
+ if not is_torch_buffer(running_mean, exported_program):
118
+ continue
119
+ if not is_torch_buffer(running_var, exported_program):
120
+ continue
121
+
122
+ input_shape = extract_shape(input_)
123
+ assert len(input_shape) == 4
124
+ C = input_shape[1]
125
+
126
+ weight_value = (
127
+ get_torch_param_value(weight, exported_program)
128
+ if weight
129
+ else torch.tensor([1] * C)
130
+ )
131
+ bias_value = (
132
+ get_torch_param_value(bias, exported_program)
133
+ if bias
134
+ else torch.tensor([0] * C)
135
+ )
136
+ mean_value = get_torch_buffer_value(running_mean, exported_program)
137
+ var_value = get_torch_buffer_value(running_var, exported_program)
138
+
139
+ assert isinstance(weight_value, torch.Tensor)
140
+ assert isinstance(bias_value, torch.Tensor)
141
+ assert isinstance(mean_value, torch.Tensor)
142
+ assert isinstance(var_value, torch.Tensor)
143
+
144
+ assert (
145
+ weight_value.shape
146
+ == bias_value.shape
147
+ == mean_value.shape
148
+ == var_value.shape
149
+ )
150
+ # Calculate constants for mul and add
151
+ mul_const = weight_value / torch.sqrt(var_value + eps)
152
+ add_const = bias_value - (mul_const * mean_value)
153
+ # N, C, H, W
154
+ assert len(mul_const) == len(add_const) == C
155
+ # reshape along with channel dimension
156
+ mul_const = mul_const.view(1, mul_const.shape[0], 1, 1)
157
+ add_const = add_const.view(1, add_const.shape[0], 1, 1)
158
+
159
+ # Placeholder nodes must be the first N nodes in the nodes list of a graph.
160
+ # Therefore, insert the newly created placeholders at the start of the node list.
161
+ with exported_program.graph.inserting_before(
162
+ get_first_user_input(exported_program)
163
+ ):
164
+ mul_const_node = add_placeholder(
165
+ exported_program,
166
+ mul_const,
167
+ prefix=f"{node.name}_mul_const",
168
+ )
169
+ add_const_node = add_placeholder(
170
+ exported_program,
171
+ add_const,
172
+ prefix=f"{node.name}_add_const",
133
173
  )
134
- bias_value = (
135
- get_torch_param_value(bias, exported_program)
136
- if bias
137
- else torch.tensor([0] * C)
174
+
175
+ with gm.graph.inserting_before(node):
176
+ mul = graph.call_function(
177
+ torch.ops.aten.mul.Tensor,
178
+ args=(input_, mul_const_node),
138
179
  )
139
- mean_value = get_torch_buffer_value(running_mean, exported_program)
140
- var_value = get_torch_buffer_value(running_var, exported_program)
141
-
142
- assert isinstance(weight_value, torch.Tensor)
143
- assert isinstance(bias_value, torch.Tensor)
144
- assert isinstance(mean_value, torch.Tensor)
145
- assert isinstance(var_value, torch.Tensor)
146
-
147
- assert (
148
- weight_value.shape
149
- == bias_value.shape
150
- == mean_value.shape
151
- == var_value.shape
180
+ add = graph.call_function(
181
+ torch.ops.aten.add.Tensor,
182
+ args=(mul, add_const_node),
152
183
  )
153
- # Calculate constants for mul and add
154
- mul_const = weight_value / torch.sqrt(var_value + eps)
155
- add_const = bias_value - (mul_const * mean_value)
156
- # N, C, H, W
157
- assert len(mul_const) == len(add_const) == C
158
- # reshape along with channel dimension
159
- mul_const = mul_const.view(1, mul_const.shape[0], 1, 1)
160
- add_const = add_const.view(1, add_const.shape[0], 1, 1)
161
-
162
- # Placeholder nodes must be the first N nodes in the nodes list of a graph.
163
- # Therefore, insert the newly created placeholders at the start of the node list.
164
- with exported_program.graph.inserting_before(
165
- get_first_user_input(exported_program)
166
- ):
167
- mul_const_node = add_placeholder(
168
- exported_program,
169
- mul_const,
170
- prefix=f"{node.name}_mul_const",
171
- )
172
- add_const_node = add_placeholder(
173
- exported_program,
174
- add_const,
175
- prefix=f"{node.name}_add_const",
176
- )
177
-
178
- with gm.graph.inserting_before(node):
179
- mul = graph.call_function(
180
- torch.ops.aten.mul.Tensor,
181
- args=(input_, mul_const_node),
182
- )
183
- add = graph.call_function(
184
- torch.ops.aten.add.Tensor,
185
- args=(mul, add_const_node),
186
- )
187
- # Not set meta for propagating replacing get_item's meta.
188
- get_item, *_ = node.users.keys()
189
- get_item.replace_all_uses_with(add, propagate_meta=True)
190
-
191
- fill_meta_val(exported_program)
192
- modified = True
184
+ # Not set meta for propagating replacing get_item's meta.
185
+ get_item, *_ = node.users.keys()
186
+ get_item.replace_all_uses_with(add, propagate_meta=True)
187
+
188
+ fill_meta_val(exported_program)
189
+ logger.debug(f"{node.name} is decomposed to {mul.name} and {add.name}")
190
+ modified = True
193
191
 
194
192
  gm.graph.eliminate_dead_code()
195
193
  gm.graph.lint()
@@ -71,9 +71,9 @@ class DecomposeFakeQuantize(PassBase):
71
71
  gm = exported_program.graph_module
72
72
  qd = torch.ops.quantized_decomposed # type: ignore[return]
73
73
  for node in gm.graph.nodes:
74
- if node.op == "call_function" and node.target in [
75
- torch.ops.aten.fake_quantize_per_tensor_affine.default
76
- ]:
74
+ if node.op != "call_function":
75
+ continue
76
+ if node.target in [torch.ops.aten.fake_quantize_per_tensor_affine.default]:
77
77
  # tensor, scale, zero_p, quant_min, quant_max
78
78
  assert len(node.args) == 5
79
79
  _, _, _, quant_min, quant_max = node.args
@@ -97,9 +97,7 @@ class DecomposeFakeQuantize(PassBase):
97
97
  node.replace_all_uses_with(dequnt, propagate_meta=True)
98
98
  modified = True
99
99
 
100
- if node.op == "call_function" and node.target in [
101
- torch.ops.aten.fake_quantize_per_channel_affine.default
102
- ]:
100
+ if node.target in [torch.ops.aten.fake_quantize_per_channel_affine.default]:
103
101
  fq_args = FakeQuantizePerChannelArgs(*node.args, **node.kwargs)
104
102
  quant_min = fq_args.quant_min
105
103
  quant_max = fq_args.quant_max
@@ -202,9 +202,10 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
202
202
  gm = exported_program.graph_module
203
203
  qd = torch.ops.quantized_decomposed # type: ignore[return]
204
204
  for node in gm.graph.nodes:
205
+ if node.op != "call_function":
206
+ continue
205
207
  if (
206
- node.op == "call_function"
207
- and node.target
208
+ node.target
208
209
  == torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default
209
210
  ):
210
211
  # tensor, scale, zero_p, fake_quant_enabled, quant_min, quant_max
@@ -247,8 +248,7 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
247
248
  mask_user.args = ((mask_user.args[0][0],),)
248
249
  modified = True
249
250
  if (
250
- node.op == "call_function"
251
- and node.target
251
+ node.target
252
252
  == torch.ops.aten.fake_quantize_per_tensor_affine.tensor_qparams
253
253
  ):
254
254
  fq_args = FakeQuantizePerTensorTQParamArgs(*node.args, **node.kwargs)
@@ -25,6 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
25
25
  from tico.utils import logging
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
28
29
  from tico.utils.validate_args_kwargs import NativeGroupNormArgs, NativeLayerNormArgs
29
30
 
30
31
 
@@ -115,13 +116,13 @@ class DecomposeGroupNorm(PassBase):
115
116
  modified = False
116
117
 
117
118
  for node in graph.nodes:
118
- if node.op != "call_function":
119
- continue
120
-
121
- if node.target not in [
122
- torch.ops.aten.native_layer_norm.default,
123
- torch.ops.aten.native_group_norm.default,
124
- ]:
119
+ if not is_target_node(
120
+ node,
121
+ [
122
+ torch.ops.aten.native_layer_norm.default,
123
+ torch.ops.aten.native_group_norm.default,
124
+ ],
125
+ ):
125
126
  continue
126
127
 
127
128
  if node.target == torch.ops.aten.native_layer_norm.default:
@@ -26,6 +26,7 @@ from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
26
26
  from tico.utils.graph import add_placeholder
27
27
  from tico.utils.passes import PassBase, PassResult
28
28
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.utils import is_target_node
29
30
  from tico.utils.validate_args_kwargs import Conv2DArgs
30
31
 
31
32
 
@@ -88,9 +89,7 @@ class DecomposeGroupedConv2d(PassBase):
88
89
  modified = False
89
90
 
90
91
  for node in graph.nodes:
91
- if node.op != "call_function":
92
- continue
93
- if not node.target in ops.aten.conv2d:
92
+ if not is_target_node(node, ops.aten.conv2d):
94
93
  continue
95
94
 
96
95
  args = Conv2DArgs(*node.args)