tico 0.1.0.dev250610__py3-none-any.whl → 0.1.0.dev250611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37):
  1. tico/__init__.py +1 -1
  2. tico/passes/cast_aten_where_arg_type.py +65 -62
  3. tico/passes/cast_mixed_type_args.py +2 -5
  4. tico/passes/convert_conv1d_to_conv2d.py +3 -4
  5. tico/passes/convert_repeat_to_expand_copy.py +5 -9
  6. tico/passes/decompose_addmm.py +41 -48
  7. tico/passes/decompose_batch_norm.py +97 -99
  8. tico/passes/decompose_fake_quantize.py +4 -6
  9. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -4
  10. tico/passes/decompose_group_norm.py +8 -7
  11. tico/passes/decompose_grouped_conv2d.py +2 -3
  12. tico/passes/decompose_slice_scatter.py +2 -4
  13. tico/passes/extract_dtype_kwargs.py +2 -1
  14. tico/passes/fuse_leading_unsqueeze_reshape.py +4 -7
  15. tico/passes/fuse_redundant_reshape_to_mean.py +3 -5
  16. tico/passes/legalize_causal_mask_value.py +2 -7
  17. tico/passes/legalize_predefined_layout_operators.py +2 -3
  18. tico/passes/lower_pow2_to_mul.py +5 -7
  19. tico/passes/lower_to_resize_nearest_neighbor.py +6 -10
  20. tico/passes/lower_to_slice.py +3 -9
  21. tico/passes/merge_consecutive_cat.py +2 -4
  22. tico/passes/remove_nop.py +2 -3
  23. tico/passes/remove_redundant_assert_nodes.py +2 -1
  24. tico/passes/remove_redundant_expand.py +5 -9
  25. tico/passes/remove_redundant_permute.py +6 -5
  26. tico/passes/remove_redundant_reshape.py +17 -34
  27. tico/passes/remove_redundant_slice.py +2 -4
  28. tico/passes/remove_redundant_to_copy.py +2 -4
  29. tico/passes/segment_index_select.py +2 -4
  30. tico/serialize/operators/op_where.py +2 -2
  31. tico/utils/utils.py +4 -6
  32. {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250611.dist-info}/METADATA +1 -1
  33. {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250611.dist-info}/RECORD +37 -37
  34. {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250611.dist-info}/LICENSE +0 -0
  35. {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250611.dist-info}/WHEEL +0 -0
  36. {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250611.dist-info}/entry_points.txt +0 -0
  37. {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250611.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
25
25
  from tico.utils import logging
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
- from tico.utils.utils import enforce_type
28
+ from tico.utils.utils import enforce_type, is_target_node
29
29
 
30
30
 
31
31
  @trace_graph_diff_on_pass
@@ -82,9 +82,7 @@ class DecomposeSliceScatter(PassBase):
82
82
  modified = False
83
83
 
84
84
  for node in graph.nodes:
85
- if node.op != "call_function":
86
- continue
87
- if node.target != torch.ops.aten.slice_scatter.default:
85
+ if not is_target_node(node, torch.ops.aten.slice_scatter.default):
88
86
  continue
89
87
 
90
88
  @enforce_type
@@ -23,6 +23,7 @@ from torch.utils import _pytree as pytree
23
23
  from tico.utils import logging
24
24
  from tico.utils.passes import PassBase, PassResult
25
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
26
27
 
27
28
 
28
29
  def _extract_to_output(node: torch.fx.Node, graph: torch.fx.Graph) -> bool:
@@ -107,7 +108,7 @@ class ExtractDtypeKwargsPass(PassBase):
107
108
  graph: torch.fx.Graph = graph_module.graph
108
109
  modified = False
109
110
  for node in graph.nodes:
110
- if not node.op == "call_function" or node.target not in self.target_ops:
111
+ if not is_target_node(node, list(self.target_ops.keys())):
111
112
  continue
112
113
  if "dtype" not in node.kwargs:
113
114
  continue
@@ -22,7 +22,7 @@ from tico.serialize.circle_mapping import extract_shape
22
22
  from tico.utils import logging
23
23
  from tico.utils.passes import PassBase, PassResult
24
24
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
- from tico.utils.utils import is_single_use_target_node
25
+ from tico.utils.utils import is_target_node
26
26
  from tico.utils.validate_args_kwargs import PermuteArgs, ReshapeArgs
27
27
 
28
28
 
@@ -55,20 +55,17 @@ class FuseLeadingUnsqueezeReshape(PassBase):
55
55
  graph = gm.graph
56
56
  modified = False
57
57
  for reshape_back in graph.nodes:
58
- if (
59
- reshape_back.op != "call_function"
60
- or reshape_back.target not in ops.aten.reshape
61
- ):
58
+ if not is_target_node(reshape_back, ops.aten.reshape):
62
59
  continue
63
60
  reshape_back_args = ReshapeArgs(*reshape_back.args, **reshape_back.kwargs) # type: ignore[arg-type]
64
61
  permute, reshape_back_size = reshape_back_args.input, reshape_back_args.size
65
62
 
66
- if not is_single_use_target_node(permute, ops.aten.permute):
63
+ if not is_target_node(permute, ops.aten.permute):
67
64
  continue
68
65
  permute_args = PermuteArgs(*permute.args, **permute.kwargs) # type: ignore[arg-type]
69
66
  reshape_front, permute_dims = permute_args.input, permute_args.dims
70
67
 
71
- if not is_single_use_target_node(reshape_front, ops.aten.reshape):
68
+ if not is_target_node(reshape_front, ops.aten.reshape):
72
69
  continue
73
70
  reshape_front_args = ReshapeArgs(*reshape_front.args, **reshape_front.kwargs) # type: ignore[arg-type]
74
71
  reshape_front_input, reshape_front_size = (
@@ -25,6 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
25
25
  from tico.utils import logging
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
28
29
 
29
30
 
30
31
  @trace_graph_diff_on_pass
@@ -45,10 +46,7 @@ class FuseRedundantReshapeToMean(PassBase):
45
46
  graph = graph_module.graph
46
47
  modified = False
47
48
  for node in graph.nodes:
48
- if not node.op == "call_function":
49
- continue
50
-
51
- if node.target != torch.ops.aten.mean.dim:
49
+ if not is_target_node(node, torch.ops.aten.mean.dim):
52
50
  continue
53
51
 
54
52
  # If mean is being used in other nodes, do not fuse it.
@@ -56,7 +54,7 @@ class FuseRedundantReshapeToMean(PassBase):
56
54
  continue
57
55
 
58
56
  user_node = next(iter(node.users))
59
- if user_node.target not in ops.aten.reshape:
57
+ if not is_target_node(user_node, ops.aten.reshape):
60
58
  continue
61
59
 
62
60
  mean_args, mean_kwargs = pytree.tree_map_only(
@@ -20,10 +20,10 @@ import torch
20
20
  from torch.export import ExportedProgram
21
21
 
22
22
  from tico.passes import ops
23
-
24
23
  from tico.utils import logging
25
24
  from tico.utils.passes import PassBase, PassResult
26
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
27
27
  from tico.utils.validate_args_kwargs import AddTensorArgs
28
28
 
29
29
 
@@ -53,14 +53,9 @@ class LegalizeCausalMaskValue(PassBase):
53
53
  graph = graph_module.graph
54
54
  modified = False
55
55
  for node in graph.nodes:
56
- if not node.op == "call_function":
57
- continue
58
-
59
- if not node.target in ops.aten.add:
56
+ if not is_target_node(node, ops.aten.add):
60
57
  continue
61
58
 
62
- assert len(node.args) == 2
63
-
64
59
  args = AddTensorArgs(*node.args, **node.kwargs)
65
60
  input = args.input
66
61
  other = args.other
@@ -25,6 +25,7 @@ from tico.utils import logging
25
25
  from tico.utils.errors import NotYetSupportedError
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
28
29
  from tico.utils.validate_args_kwargs import (
29
30
  AvgPool2dArgs,
30
31
  Conv2DArgs,
@@ -363,11 +364,9 @@ class LegalizePreDefinedLayoutOperators(PassBase):
363
364
  graph = graph_module.graph
364
365
  modified = False
365
366
  for node in graph.nodes:
366
- if not node.op == "call_function":
367
+ if not is_target_node(node, list(target_to_legalize_func.keys())):
367
368
  continue
368
369
 
369
- if node.target not in target_to_legalize_func:
370
- continue
371
370
  modified |= target_to_legalize_func[node.target](exported_program, node)
372
371
 
373
372
  graph.eliminate_dead_code()
@@ -22,6 +22,8 @@ from torch.export import ExportedProgram
22
22
  from tico.utils import logging
23
23
  from tico.utils.passes import PassBase, PassResult
24
24
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_target_node
26
+ from tico.utils.validate_args_kwargs import PowTensorScalarArgs
25
27
 
26
28
 
27
29
  @trace_graph_diff_on_pass
@@ -42,15 +44,11 @@ class LowerPow2ToMul(PassBase):
42
44
  graph = graph_module.graph
43
45
  modified = False
44
46
  for node in graph.nodes:
45
- if not node.op == "call_function":
47
+ if not is_target_node(node, torch.ops.aten.pow.Tensor_Scalar):
46
48
  continue
47
49
 
48
- if node.target != torch.ops.aten.pow.Tensor_Scalar:
49
- continue
50
-
51
- assert len(node.args) == 2, len(node.args)
52
- in_, exp = node.args
53
- assert isinstance(in_, torch.fx.Node), type(in_)
50
+ args = PowTensorScalarArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
51
+ in_, exp = args.input, args.exponent
54
52
 
55
53
  if exp != 2:
56
54
  continue
@@ -12,12 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING
15
+ from typing import Optional, TYPE_CHECKING
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  import torch.fx
19
- from typing import Optional
20
-
21
19
  import torch
22
20
  from torch.export import ExportedProgram
23
21
 
@@ -26,6 +24,7 @@ from tico.utils import logging
26
24
  from tico.utils.errors import NotYetSupportedError
27
25
  from tico.utils.passes import PassBase, PassResult
28
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.utils import is_target_node
29
28
  from tico.utils.validate_args_kwargs import IndexArgs, UpsampleNearest2DVecArgs
30
29
 
31
30
 
@@ -194,13 +193,10 @@ class LowerToResizeNearestNeighbor(PassBase):
194
193
  graph_module = exported_program.graph_module
195
194
  graph = graph_module.graph
196
195
  for node in graph.nodes:
197
- if not node.op == "call_function":
198
- continue
199
-
200
- if node.target not in [
201
- torch.ops.aten.index.Tensor,
202
- torch.ops.aten.upsample_nearest2d.vec,
203
- ]:
196
+ if not is_target_node(
197
+ node,
198
+ [torch.ops.aten.index.Tensor, torch.ops.aten.upsample_nearest2d.vec],
199
+ ):
204
200
  continue
205
201
 
206
202
  resize_nearest_neighbor = None
@@ -28,13 +28,12 @@ from torch._export.utils import (
28
28
  from torch.export import ExportedProgram
29
29
 
30
30
  from tico.passes import ops
31
-
32
31
  from tico.serialize.circle_graph import extract_shape
33
32
  from tico.utils import logging
34
-
35
33
  from tico.utils.graph import is_single_value_tensor
36
34
  from tico.utils.passes import PassBase, PassResult
37
35
  from tico.utils.trace_decorators import trace_const_diff_on_pass
36
+ from tico.utils.utils import is_target_node
38
37
  from tico.utils.validate_args_kwargs import IndexSelectArgs, SelectCopyIntArgs
39
38
 
40
39
 
@@ -85,10 +84,7 @@ class LowerSelectCopyToSlice(PassBase):
85
84
  graph = graph_module.graph
86
85
  modified = False
87
86
  for node in graph.nodes:
88
- if not node.op == "call_function":
89
- continue
90
-
91
- if not node.target in ops.aten.select:
87
+ if not is_target_node(node, ops.aten.select):
92
88
  continue
93
89
 
94
90
  args = SelectCopyIntArgs(*node.args, **node.kwargs)
@@ -163,11 +159,9 @@ class LowerIndexSelectToSlice(PassBase):
163
159
  graph = graph_module.graph
164
160
  modified = False
165
161
  for node in graph.nodes:
166
- if not node.op == "call_function":
162
+ if not is_target_node(node, ops.aten.index_select):
167
163
  continue
168
164
 
169
- if not node.target in ops.aten.index_select:
170
- continue
171
165
  args = IndexSelectArgs(*node.args, **node.kwargs)
172
166
  input = args.input
173
167
  dim = args.dim
@@ -18,6 +18,7 @@ from tico.passes import ops
18
18
  from tico.utils import logging
19
19
  from tico.utils.passes import PassBase, PassResult
20
20
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
21
+ from tico.utils.utils import is_target_node
21
22
  from tico.utils.validate_args_kwargs import CatArgs
22
23
 
23
24
 
@@ -37,10 +38,7 @@ class MergeConsecutiveCat(PassBase):
37
38
  graph = graph_module.graph
38
39
  modified = False
39
40
  for cat in graph.nodes:
40
- if not cat.op == "call_function":
41
- continue
42
-
43
- if not cat.target in ops.aten.cat:
41
+ if not is_target_node(cat, ops.aten.cat):
44
42
  continue
45
43
 
46
44
  args = CatArgs(*cat.args, **cat.kwargs) # type: ignore[arg-type]
tico/passes/remove_nop.py CHANGED
@@ -23,6 +23,7 @@ from tico.passes import ops
23
23
  from tico.utils import logging
24
24
  from tico.utils.passes import PassBase, PassResult
25
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
26
27
 
27
28
 
28
29
  @trace_graph_diff_on_pass
@@ -51,11 +52,9 @@ class RemoveNop(PassBase):
51
52
  graph = graph_module.graph
52
53
  modified = False
53
54
  for node in graph.nodes:
54
- if not node.op == "call_function":
55
+ if not is_target_node(node, RemoveNop.target_ops):
55
56
  continue
56
57
 
57
- if not node.target in RemoveNop.target_ops:
58
- continue
59
58
  # TODO Consider memory format
60
59
  if node.target in ops.aten.clone and "memory_format" in node.kwargs:
61
60
  if node.kwargs["memory_format"] not in [
@@ -17,6 +17,7 @@ from torch.export import ExportedProgram
17
17
 
18
18
  from tico.utils.passes import PassBase, PassResult
19
19
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
20
+ from tico.utils.utils import is_target_node
20
21
 
21
22
 
22
23
  assert_node_targets = [
@@ -39,7 +40,7 @@ class RemoveRedundantAssertionNodes(PassBase):
39
40
  graph = graph_module.graph
40
41
  modified = False
41
42
  for node in graph.nodes:
42
- if node.op == "call_function" and node.target in assert_node_targets:
43
+ if is_target_node(node, assert_node_targets):
43
44
  graph.erase_node(node)
44
45
  modified = True
45
46
 
@@ -24,6 +24,8 @@ from tico.serialize.circle_mapping import extract_shape
24
24
  from tico.utils import logging
25
25
  from tico.utils.passes import PassBase, PassResult
26
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.utils import is_target_node
28
+ from tico.utils.validate_args_kwargs import ExpandArgs
27
29
 
28
30
 
29
31
  @trace_graph_diff_on_pass
@@ -42,17 +44,11 @@ class RemoveRedundantExpand(PassBase):
42
44
  graph = graph_module.graph
43
45
  modified = False
44
46
  for node in graph.nodes:
45
- if not node.op == "call_function":
47
+ if not is_target_node(node, ops.aten.expand):
46
48
  continue
47
49
 
48
- if not node.target in ops.aten.expand:
49
- continue
50
-
51
- assert len(node.args) == 2
52
-
53
- input, size = list(node.args)
54
- assert isinstance(input, torch.fx.Node), type(input)
55
- assert isinstance(size, list), type(size)
50
+ args = ExpandArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
51
+ input, size = args.input, args.size
56
52
 
57
53
  input_shape = extract_shape(input)
58
54
  if list(input_shape) != size:
@@ -14,16 +14,18 @@
14
14
 
15
15
  from typing import TYPE_CHECKING
16
16
 
17
+
17
18
  if TYPE_CHECKING:
18
19
  import torch.fx
19
20
  import torch
20
21
  from torch.export import ExportedProgram
21
22
 
22
23
  from tico.passes import ops
23
- from tico.serialize.circle_mapping import extract_shape, extract_stride
24
+ from tico.serialize.circle_mapping import extract_shape
24
25
  from tico.utils import logging
25
26
  from tico.utils.passes import PassBase, PassResult
26
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
27
29
  from tico.utils.validate_args_kwargs import PermuteArgs
28
30
 
29
31
 
@@ -75,16 +77,15 @@ class RemoveRedundantPermutePattern1(PassBase):
75
77
  graph = graph_module.graph
76
78
  modified = False
77
79
  for permute2 in graph.nodes:
78
- if not permute2.op == "call_function":
79
- continue
80
- if not permute2.target in ops.aten.permute:
80
+ if not is_target_node(permute2, ops.aten.permute):
81
81
  continue
82
+
82
83
  if len(permute2.users) != 1:
83
84
  continue
84
85
  permute2_args = PermuteArgs(*permute2.args, **permute2.kwargs) # type: ignore[arg-type]
85
86
  permute1, permute2_dims = permute2_args.input, permute2_args.dims
86
87
 
87
- if not permute1.target in ops.aten.permute:
88
+ if not is_target_node(permute1, ops.aten.permute):
88
89
  continue
89
90
  if len(permute1.users) != 1:
90
91
  continue
@@ -24,7 +24,7 @@ from tico.serialize.circle_mapping import extract_shape
24
24
  from tico.utils import logging
25
25
  from tico.utils.passes import PassBase, PassResult
26
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
- from tico.utils.utils import broadcastable, set_new_meta_val
27
+ from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
28
28
  from tico.utils.validate_args_kwargs import (
29
29
  AddTensorArgs,
30
30
  PermuteArgs,
@@ -67,12 +67,10 @@ class RemoveRedundantReshapePattern1(PassBase):
67
67
  graph = graph_module.graph
68
68
  modified = False
69
69
  for reshape1 in graph.nodes:
70
- if not reshape1.op == "call_function":
71
- continue
72
-
73
70
  ### first reshape
74
- if not reshape1.target in ops.aten.reshape:
71
+ if not is_target_node(reshape1, ops.aten.reshape):
75
72
  continue
73
+
76
74
  # Assumes that other node do not use ops in the pattern for simplisity.
77
75
  if len(reshape1.users) != 1:
78
76
  continue
@@ -86,7 +84,7 @@ class RemoveRedundantReshapePattern1(PassBase):
86
84
 
87
85
  ### permute
88
86
  permute = next(iter(reshape1.users))
89
- if not permute.target in ops.aten.permute:
87
+ if not is_target_node(permute, ops.aten.permute):
90
88
  continue
91
89
  if len(permute.users) != 1:
92
90
  continue
@@ -98,14 +96,14 @@ class RemoveRedundantReshapePattern1(PassBase):
98
96
 
99
97
  ### mul
100
98
  mul = next(iter(permute.users))
101
- if not mul.target in RemoveRedundantReshapePattern1.mul_ops:
99
+ if not is_target_node(mul, RemoveRedundantReshapePattern1.mul_ops):
102
100
  continue
103
101
  if len(mul.users) != 1:
104
102
  continue
105
103
 
106
104
  ### second reshape
107
105
  reshape2 = next(iter(mul.users))
108
- if not reshape2.target in ops.aten.reshape:
106
+ if not is_target_node(reshape2, ops.aten.reshape):
109
107
  continue
110
108
  if len(reshape2.users) != 1:
111
109
  continue
@@ -153,11 +151,8 @@ class RemoveRedundantReshapePattern2(PassBase):
153
151
  graph = graph_module.graph
154
152
  modified = False
155
153
  for reshape1 in graph.nodes:
156
- if not reshape1.op == "call_function":
157
- continue
158
-
159
154
  ### first reshape
160
- if not reshape1.target in ops.aten.reshape:
155
+ if not is_target_node(reshape1, ops.aten.reshape):
161
156
  continue
162
157
  if len(reshape1.users) != 1:
163
158
  continue
@@ -171,7 +166,7 @@ class RemoveRedundantReshapePattern2(PassBase):
171
166
 
172
167
  ### permute
173
168
  permute = next(iter(reshape1.users))
174
- if not permute.target in ops.aten.permute:
169
+ if not is_target_node(permute, ops.aten.permute):
175
170
  continue
176
171
  if len(permute.users) != 1:
177
172
  continue
@@ -183,7 +178,7 @@ class RemoveRedundantReshapePattern2(PassBase):
183
178
 
184
179
  ### second reshape
185
180
  reshape2 = next(iter(permute.users))
186
- if not reshape2.target in ops.aten.reshape:
181
+ if not is_target_node(reshape2, ops.aten.reshape):
187
182
  continue
188
183
  if len(reshape2.users) != 1:
189
184
  continue
@@ -239,20 +234,14 @@ class RemoveRedundantReshapePattern3(PassBase):
239
234
  graph = graph_module.graph
240
235
  modified = False
241
236
  for reshape_1 in graph.nodes:
242
- assert isinstance(reshape_1, torch.fx.Node), type(reshape_1)
243
237
  # reshape_1
244
- if not reshape_1.op == "call_function":
245
- continue
246
- if not reshape_1.target in ops.aten.reshape:
238
+ if not is_target_node(reshape_1, ops.aten.reshape):
247
239
  continue
248
240
  reshape_1_args = ReshapeArgs(*reshape_1.args, **reshape_1.kwargs) # type: ignore[arg-type]
249
- softmax, reshape_1_size = reshape_1_args.input, reshape_1_args.size
241
+ softmax = reshape_1_args.input
250
242
 
251
243
  # softmax
252
- assert isinstance(softmax, torch.fx.Node), type(softmax)
253
- if not softmax.op == "call_function":
254
- continue
255
- if not softmax.target in ops.aten.softmax:
244
+ if not is_target_node(softmax, ops.aten.softmax):
256
245
  continue
257
246
  if softmax.target == torch.ops.aten._softmax.default:
258
247
  softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs) # type: ignore[arg-type, assignment]
@@ -354,10 +343,9 @@ class RemoveRedundantReshapePattern4(PassBase):
354
343
  modified = False
355
344
  for reshape1 in graph.nodes:
356
345
  # reshape_1
357
- if not reshape1.op == "call_function":
358
- continue
359
- if not reshape1.target in ops.aten.reshape:
346
+ if not is_target_node(reshape1, ops.aten.reshape):
360
347
  continue
348
+
361
349
  reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type]
362
350
  reshape1_input, size = reshape1_args.input, reshape1_args.size
363
351
  assert isinstance(reshape1_input, torch.fx.Node), type(reshape1_input)
@@ -370,12 +358,10 @@ class RemoveRedundantReshapePattern4(PassBase):
370
358
 
371
359
  # reshape_2
372
360
  reshape2 = next(iter(reshape1.users))
373
- if not reshape2.op == "call_function":
361
+ if not is_target_node(reshape2, ops.aten.reshape):
374
362
  continue
375
- if not reshape2.target in ops.aten.reshape:
376
- continue
377
- reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
378
363
 
364
+ reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
379
365
  reshape2_input, reshape2_size = reshape2_args.input, reshape2_args.size
380
366
  assert isinstance(reshape2_input, torch.fx.Node), type(reshape2_input)
381
367
  assert isinstance(reshape2_size, list), type(reshape2_size)
@@ -420,10 +406,7 @@ class RemoveRedundantReshapePattern5(PassBase):
420
406
  modified = False
421
407
 
422
408
  for node in graph.nodes:
423
- if not node.op == "call_function":
424
- continue
425
-
426
- if not node.target in ops.aten.reshape:
409
+ if not is_target_node(node, ops.aten.reshape):
427
410
  continue
428
411
 
429
412
  args = ReshapeArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
@@ -19,6 +19,7 @@ from tico.serialize.circle_mapping import extract_shape
19
19
  from tico.utils import logging
20
20
  from tico.utils.passes import PassBase, PassResult
21
21
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
22
+ from tico.utils.utils import is_target_node
22
23
  from tico.utils.validate_args_kwargs import SliceArgs
23
24
 
24
25
 
@@ -38,10 +39,7 @@ class RemoveRedundantSlice(PassBase):
38
39
  graph = graph_module.graph
39
40
  modified = False
40
41
  for node in graph.nodes:
41
- if not node.op == "call_function":
42
- continue
43
-
44
- if not node.target in ops.aten.slice:
42
+ if not is_target_node(node, ops.aten.slice):
45
43
  continue
46
44
 
47
45
  args = SliceArgs(*node.args, **node.kwargs)
@@ -22,6 +22,7 @@ from tico.serialize.circle_mapping import extract_torch_dtype
22
22
  from tico.utils import logging
23
23
  from tico.utils.passes import PassBase, PassResult
24
24
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_target_node
25
26
  from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs
26
27
 
27
28
 
@@ -41,10 +42,7 @@ class RemoveRedundantToCopy(PassBase):
41
42
  graph = graph_module.graph
42
43
  modified = False
43
44
  for node in graph.nodes:
44
- if not node.op == "call_function":
45
- continue
46
-
47
- if not node.target in ops.aten.to_copy:
45
+ if not is_target_node(node, ops.aten.to_copy):
48
46
  continue
49
47
 
50
48
  args: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]
@@ -34,6 +34,7 @@ from tico.utils import logging
34
34
  from tico.utils.graph import add_placeholder, is_single_value_tensor
35
35
  from tico.utils.passes import PassBase, PassResult
36
36
  from tico.utils.trace_decorators import trace_const_diff_on_pass
37
+ from tico.utils.utils import is_target_node
37
38
  from tico.utils.validate_args_kwargs import IndexSelectArgs
38
39
 
39
40
 
@@ -78,10 +79,7 @@ class SegmentIndexSelectConst(PassBase):
78
79
  graph = graph_module.graph
79
80
  modified = False
80
81
  for node in graph.nodes:
81
- if not node.op == "call_function":
82
- continue
83
-
84
- if not node.target in ops.aten.index_select:
82
+ if not is_target_node(node, ops.aten.index_select):
85
83
  continue
86
84
 
87
85
  args = IndexSelectArgs(*node.args, **node.kwargs)
@@ -52,12 +52,12 @@ class WhereVisitor(NodeVisitor):
52
52
 
53
53
  result_true_dtype = (
54
54
  extract_torch_dtype(input)
55
- if isinstance(input, torch.fx.node.Node)
55
+ if isinstance(input, torch.fx.Node)
56
56
  else input.dtype # type: ignore[union-attr]
57
57
  )
58
58
  result_false_dtype = (
59
59
  extract_torch_dtype(other)
60
- if isinstance(other, torch.fx.node.Node)
60
+ if isinstance(other, torch.fx.Node)
61
61
  else other.dtype # type: ignore[union-attr]
62
62
  )
63
63
 
tico/utils/utils.py CHANGED
@@ -331,6 +331,7 @@ def get_quant_dtype(qmin: int, qmax: int):
331
331
  """
332
332
  known_ranges = {
333
333
  (-32768, 32767): "int16",
334
+ (-32767, 32767): "int16",
334
335
  (0, 65535): "uint16",
335
336
  (-128, 127): "int8",
336
337
  (0, 255): "uint8",
@@ -380,19 +381,18 @@ def broadcastable(
380
381
  return True
381
382
 
382
383
 
383
- def is_single_use_target_node(
384
+ def is_target_node(
384
385
  node: torch.fx.Node, target_ops: list[torch._ops.OpOverload] | torch._ops.OpOverload
385
386
  ):
386
387
  """
387
- Check whether a given node is a `call_function` node that matches one of the specified targets
388
- and is used by only one other node.
388
+ Check whether a given node is a `call_function` node that matches one of the specified targets.
389
389
 
390
390
  Args:
391
391
  node (torch.fx.Node): The node to check.
392
392
  target_ops (Iterable[Callable]): A list or set of target operations to match (e.g., ops.aten.reshape).
393
393
 
394
394
  Returns:
395
- bool: True if the node is a call_function, its target is in `target_ops`, and it has exactly one user.
395
+ bool: True if the node is a call_function, its target is in `target_ops`.
396
396
  """
397
397
  if not isinstance(target_ops, list):
398
398
  target_ops = [target_ops]
@@ -402,7 +402,5 @@ def is_single_use_target_node(
402
402
  return False
403
403
  if node.target not in target_ops:
404
404
  return False
405
- if len(node.users) != 1:
406
- return False
407
405
 
408
406
  return True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250610
3
+ Version: 0.1.0.dev250611
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN