tico 0.1.0.dev250414__py3-none-any.whl → 0.1.0.dev250415__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -22,7 +22,7 @@ from tico.config import CompileConfigV1, get_default_config
22
22
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
23
23
 
24
24
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
25
- __version__ = '0.1.0.dev250414'
25
+ __version__ = '0.1.0.dev250415'
26
26
 
27
27
 
28
28
  if Version(torch.__version__) < Version("2.5.0"):
@@ -32,6 +32,7 @@ from tico.utils.validate_args_kwargs import (
32
32
  LinearArgs,
33
33
  MulTensorArgs,
34
34
  PermuteArgs,
35
+ ReshapeArgs,
35
36
  )
36
37
 
37
38
 
@@ -102,6 +103,26 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
102
103
  output dtype (int16) is updated to the input dtype (uint8), which breaks the semantics.
103
104
  This problem can occur in the tools (ex: circle2circle) that automatically apply type inference.
104
105
  - To resolve the issue, we insert quantize operators not to violate circle's type inference logic.
106
+ - NOTE For some cases, Quantize Op is inserted before the operators.
107
+
108
+ Let's assume Reshape Op's input is int16 and output is uint8. There are two possible places to insert
109
+ Quantize Op.
110
+
111
+ 1. Insert Quantize before Reshape.
112
+
113
+ ```
114
+ Predecessor (int16)-> Quantize (uint8) -> Reshape (uint8) -> ...
115
+ ```
116
+
117
+ 2. Insert Quantize after Reshape.
118
+
119
+ ```
120
+ Predecessor (int16)-> Reshape (int16) -> Quantize (uint8) -> ...
121
+ ```
122
+
123
+ Comparing 1) and 2), the difference is that Reshape operation is conducted in uint8 or int16.
124
+ We go with 1), which does Reshape in uint8, for faster execution. Note that Reshape Op does not
125
+ change the value, so its dtype does not affect accuracy.
105
126
  """
106
127
 
107
128
  def __init__(self):
@@ -264,7 +285,38 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
264
285
  quantize = _insert_quantize_op_before(node, inp)
265
286
 
266
287
  quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
267
- node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
288
+ logger.debug(
289
+ f"quantize_per_tensor.default is inserted before {node.name}."
290
+ )
291
+ elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
292
+ quantize = _insert_quantize_op_after(node)
293
+
294
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
295
+ node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
296
+ logger.debug(
297
+ f"quantize_per_tensor.default is inserted after {node.name}."
298
+ )
299
+ else:
300
+ raise NotYetSupportedError("Unsupported dtype")
301
+ elif node.target == torch.ops.aten.reshape.default:
302
+ reshape_args = ReshapeArgs(*node.args, **node.kwargs)
303
+ inp = reshape_args.input
304
+
305
+ if QPARAM_KEY not in inp.meta:
306
+ continue
307
+
308
+ if QPARAM_KEY not in node.meta:
309
+ continue
310
+
311
+ if qparam_dtype(inp) == qparam_dtype(node):
312
+ continue
313
+
314
+ if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
315
+ # A new Quantize Op (s16 to u8) is inserted before (not after)
316
+ # reshape Op to reduce tensor size earlier
317
+ quantize = _insert_quantize_op_before(node, inp)
318
+
319
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
268
320
  logger.debug(
269
321
  f"quantize_per_tensor.default is inserted before {node.name}."
270
322
  )
@@ -35,32 +35,6 @@ class LowerToResizeNearestNeighbor(PassBase):
35
35
  This pass lowers `aten.index` and `aten.upsample_nearest2d.vec` to `circle_custom.resize_nearest_neighbor` when it is possible.
36
36
 
37
37
  Until torch 2.7, `torch.nn.functional.interpolate` is converted to `aten.index` op.
38
-
39
- [EXAMPLE]
40
- class InterpolateDouble(torch.nn.Module):
41
- def __init__(self):
42
- super().__init__()
43
-
44
- def forward(self, x):
45
- return torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
46
-
47
- def get_example_inputs(self):
48
- return (torch.randn(1, 2, 3, 4),)
49
-
50
- [EXPORTED GRAPH]
51
- [constants]
52
- _prop_tensor_constant0 = tensor([0, 0, 1, 1, 2, 2, 3, 3]
53
- _prop_tensor_constant1 = tensor([[0], [0], [1], [1], [2], [2]])
54
-
55
- [graph]
56
- %_prop_tensor_constant0 : [num_users=1] = placeholder[target=_prop_tensor_constant0]
57
- %_prop_tensor_constant1 : [num_users=1] = placeholder[target=_prop_tensor_constant1]
58
- %x : [num_users=1] = placeholder[target=x]
59
- %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
60
- %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [None, None, %_prop_tensor_constant1, %_prop_tensor_constant0]), kwargs = {})
61
- %_to_copy_3 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%index,), kwargs = {dtype: torch.float32})
62
- return (_to_copy_3,)
63
-
64
38
  [BEFORE PASS]
65
39
  input - aten.index - output
66
40
 
@@ -68,6 +42,11 @@ class LowerToResizeNearestNeighbor(PassBase):
68
42
  input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output
69
43
 
70
44
  Since torch 2.8, `torch.nn.functional.interpolate` is converted to `aten.upsample_nearest2d.vec` op.
45
+ [BEFORE PASS]
46
+ input - aten.upsample_nearest2d.vec - output
47
+
48
+ [AFTER PASS]
49
+ input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output
71
50
  """
72
51
 
73
52
  def __init__(self):
@@ -17,37 +17,56 @@ from typing import TYPE_CHECKING
17
17
  if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
+ from torch._export.utils import (
21
+ get_buffer,
22
+ get_lifted_tensor_constant,
23
+ get_param,
24
+ is_buffer,
25
+ is_lifted_tensor_constant,
26
+ is_param,
27
+ )
20
28
  from torch.export import ExportedProgram
21
29
 
22
30
  from tico.passes import ops
23
31
 
24
32
  from tico.serialize.circle_graph import extract_shape
25
33
  from tico.utils import logging
34
+
35
+ from tico.utils.graph import is_single_value_tensor
26
36
  from tico.utils.passes import PassBase, PassResult
27
37
  from tico.utils.trace_decorators import trace_const_diff_on_pass
28
- from tico.utils.validate_args_kwargs import SelectCopyIntArgs
38
+ from tico.utils.validate_args_kwargs import IndexSelectArgs, SelectCopyIntArgs
29
39
 
30
40
 
31
- @trace_const_diff_on_pass
32
- class LowerToSlice(PassBase):
41
+ def passes():
33
42
  """
34
43
  This pass lowers aten.ops.select/select_copy.int to aten.ops.slice.
35
44
  We support only when it is index in args, which is a constant tensor.
36
45
  Since the index in node's args isn't a constant tensor, we can't support converting the below op list yet.
37
- - torch.ops.aten.index_select.default
46
+
47
+ TODO Support below with const indices
38
48
  - torch.ops.aten.embedding.default
39
49
  - torch.ops.aten.index.Tensor
50
+ """
51
+ return [
52
+ LowerSelectCopyToSlice(),
53
+ LowerIndexSelectToSlice(),
54
+ ]
55
+
40
56
 
57
+ @trace_const_diff_on_pass
58
+ class LowerSelectCopyToSlice(PassBase):
59
+ """
41
60
  [before]
42
- input (tensor, dim, *index)
61
+ input
43
62
  |
44
- select
63
+ select (tensor, dim, *index)
45
64
  |
46
65
  output
47
66
 
48
67
  [after]
49
68
 
50
- input (tensor, dim, *index)
69
+ input
51
70
  |
52
71
  slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
53
72
  |
@@ -110,3 +129,98 @@ class LowerToSlice(PassBase):
110
129
  graph_module.recompile()
111
130
 
112
131
  return PassResult(modified)
132
+
133
+
134
+ @trace_const_diff_on_pass
135
+ class LowerIndexSelectToSlice(PassBase):
136
+ """
137
+
138
+ [before]
139
+ input
140
+ |
141
+ index_select.default (tensor, dim, *index)
142
+ |
143
+ output
144
+
145
+ [after]
146
+
147
+ input
148
+ |
149
+ slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
150
+ |
151
+ reshape (input=slice_copy, size=select_shape)
152
+ |
153
+ output
154
+ """
155
+
156
+ def __init__(self):
157
+ super().__init__()
158
+
159
+ def call(self, exported_program: ExportedProgram) -> PassResult:
160
+ logger = logging.getLogger(__name__)
161
+
162
+ graph_module = exported_program.graph_module
163
+ graph = graph_module.graph
164
+ modified = False
165
+ for node in graph.nodes:
166
+ if not node.op == "call_function":
167
+ continue
168
+
169
+ if not node.target in ops.aten.index_select:
170
+ continue
171
+ args = IndexSelectArgs(*node.args, **node.kwargs)
172
+ input = args.input
173
+ dim = args.dim
174
+ index = args.index
175
+
176
+ input_shape = extract_shape(input)
177
+ if dim < 0:
178
+ dim = dim % len(input_shape)
179
+
180
+ if isinstance(index, torch.fx.Node):
181
+ if is_lifted_tensor_constant(exported_program, index):
182
+ index = get_lifted_tensor_constant(exported_program, index) # type: ignore[assignment]
183
+ elif is_param(exported_program, index):
184
+ index = get_param(exported_program, index) # type: ignore[assignment]
185
+ elif is_buffer(exported_program, index):
186
+ index = get_buffer(exported_program, index) # type: ignore[assignment]
187
+ else:
188
+ continue
189
+
190
+ if not isinstance(index, torch.Tensor):
191
+ continue
192
+
193
+ if not is_single_value_tensor(index):
194
+ # need to be lowered by LowerIndexSelect pass
195
+ continue
196
+ index_int = index.item() # convert scalar tensor to int
197
+
198
+ start = index_int
199
+ end = index_int + 1
200
+ step = 1
201
+ slice_copy_args = (input, dim, start, end, step)
202
+
203
+ with graph.inserting_after(node):
204
+ # slice
205
+ slice_node = graph.call_function(
206
+ torch.ops.aten.slice.Tensor, args=slice_copy_args
207
+ )
208
+ node_shape = extract_shape(node)
209
+ with graph.inserting_after(slice_node):
210
+ # reshape
211
+ reshape_args = (slice_node, list(node_shape))
212
+ reshape_node = graph.call_function(
213
+ torch.ops.aten.reshape.default, args=reshape_args
214
+ )
215
+ node.replace_all_uses_with(reshape_node, propagate_meta=False)
216
+
217
+ modified = True
218
+ logger.debug(
219
+ f"{node.name} is replaced with {slice_node.name} and {reshape_node.name} operators"
220
+ )
221
+
222
+ graph.eliminate_dead_code()
223
+ graph.lint()
224
+ graph_module.recompile()
225
+
226
+ return PassResult(modified)
tico/utils/convert.py CHANGED
@@ -58,7 +58,7 @@ from tico.passes.legalize_predefined_layout_operators import (
58
58
  )
59
59
  from tico.passes.lower_pow2_to_mul import LowerPow2ToMul
60
60
  from tico.passes.lower_to_resize_nearest_neighbor import LowerToResizeNearestNeighbor
61
- from tico.passes.lower_to_slice import LowerToSlice
61
+ from tico.passes.lower_to_slice import passes as LowerToSlicePasses
62
62
  from tico.passes.merge_consecutive_cat import MergeConsecutiveCat
63
63
  from tico.passes.remove_nop import RemoveNop
64
64
  from tico.passes.remove_redundant_assert_nodes import RemoveRedundantAssertionNodes
@@ -224,7 +224,7 @@ def convert_exported_module_to_circle(
224
224
  LegalizePreDefinedLayoutOperators(),
225
225
  LowerPow2ToMul(),
226
226
  ConvertConv1dToConv2d(),
227
- LowerToSlice(),
227
+ *LowerToSlicePasses(),
228
228
  ]
229
229
  )
230
230
  circle_legalize.run(exported_program)
@@ -852,7 +852,7 @@ class ReshapeArgs:
852
852
  @dataclass
853
853
  class ResizeNearestNeighborArgs:
854
854
  """
855
- # Maps from `torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode='nearest')` case.
855
+ # Mapped from `torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode='nearest')` case.
856
856
  """
857
857
 
858
858
  input: torch.fx.Node
@@ -1,6 +1,6 @@
1
1
  This file provides full text of licenses used in this project
2
2
 
3
- - Apache Licence 2.0
3
+ - Apache License 2.0
4
4
  - BSD 3-Clause
5
5
 
6
6
  ...............................................................................
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250414
3
+ Version: 0.1.0.dev250415
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=drwOk40YiN-xglot_aW8vG4uxq2GWx4F853Gxd6Jdv0,1181
1
+ tico/__init__.py,sha256=2PwjELyAXeZkJt02ODDKWpNPwQj5W4GtKFjH-VaZBdk,1181
2
2
  tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -51,7 +51,7 @@ tico/experimental/quantization/evaluation/executor/circle_executor.py,sha256=eCC
51
51
  tico/experimental/quantization/evaluation/executor/triv24_executor.py,sha256=sUoXl6oOO2arAKaNjOBg7HiQja145_Jv6qgY7XtR7A8,5159
52
52
  tico/experimental/quantization/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
53
53
  tico/experimental/quantization/passes/fold_quant_ops.py,sha256=Jq5wmQDhdjsXxae2p6TnZj2gY5UMBEQ-sHkTodgkfUs,3327
54
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=pqKPk-F5AYsas3aiCC0SmNginvHFcPKjSgmOg4JWlJ0,10822
54
+ tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=i4rkM1vlN85fXA9oOrU25o8KWAaqA65NKngTX6MgctQ,12960
55
55
  tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
56
56
  tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
57
57
  tico/experimental/quantization/passes/remove_weight_dequant_op.py,sha256=V-0tZqLxHbB0Dvw2D69aesfGdr3vmhvFlfNAIYnGTuo,6361
@@ -79,8 +79,8 @@ tico/passes/fuse_redundant_reshape_to_mean.py,sha256=SzHLLL2Yfj4k1c2L5i4PVI9EFUi
79
79
  tico/passes/legalize_causal_mask_value.py,sha256=KZc_UPk7CGPXO35JOu6dVrOzRJx-ZpJoVvuzz-GvQek,4080
80
80
  tico/passes/legalize_predefined_layout_operators.py,sha256=OKVKxw039Bl9df7YCt17mWOHaGhQLWkJeSIPQpWdNBM,16224
81
81
  tico/passes/lower_pow2_to_mul.py,sha256=imx9CoKG4bLyNqv4F-Z203s_P0-0SdRH-y4_Q0PTZVo,2304
82
- tico/passes/lower_to_resize_nearest_neighbor.py,sha256=EXH-xuXggLFSLaDEknxG_TdY4wHznP4zDG6M3qzq1A0,9982
83
- tico/passes/lower_to_slice.py,sha256=XUZ5DLDmWQeR2LU7yC4DDWK0lDwsdCS7BDiqambncA4,3651
82
+ tico/passes/lower_to_resize_nearest_neighbor.py,sha256=4bIxPSyNEzznTw8f8D9hMNXbZ0KmMPPPfRvITXonEz0,8881
83
+ tico/passes/lower_to_slice.py,sha256=6xK7A2hdqdBdN2XRyt2rFGnqJGcaXxmgndis4kn2q2w,7112
84
84
  tico/passes/merge_consecutive_cat.py,sha256=u1E-7axmX4yw0TMlFO5jXB0rn3xe5ASmqXcR42c9eNA,2722
85
85
  tico/passes/ops.py,sha256=BmUtfBPrSmCd2TsFc6PtOmhuznuFNBNHl12RS7CS0KI,2836
86
86
  tico/passes/remove_nop.py,sha256=5QE3inFsXgzyPT_t7pKeXNqD1LRf6ed_Mp7YMadA6AI,2707
@@ -175,7 +175,7 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
175
175
  tico/serialize/operators/op_where.py,sha256=qZDFKVQ2u8HB0Jjpto_1A4g-jfpEprAbBT3PmIGzdoY,2855
176
176
  tico/serialize/operators/utils.py,sha256=wQrcrnZ942_4SfhEW_7hIVi0tncGor-o7zuf2kmk9Io,1803
177
177
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
178
- tico/utils/convert.py,sha256=bu-9T_tBL9UawITqMPDo-T2WkI-d64XZ13NHXffKIGk,11661
178
+ tico/utils/convert.py,sha256=KCllPnvQ8bjEYR1yI72s9aNBp7Py1CzIEEpYSYZcu60,11684
179
179
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
180
180
  tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
181
181
  tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
@@ -187,10 +187,10 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
187
187
  tico/utils/register_custom_op.py,sha256=FbMcrg8o5vKWC_aoVxL2GrIcR14KFi1yKG0mFGqXkPY,21595
188
188
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
189
189
  tico/utils/utils.py,sha256=pybDU1LoNhjEplANig11lboX9yzYRkvFCSmyYth_2Do,10359
190
- tico/utils/validate_args_kwargs.py,sha256=DXW7W5x-6pQR43_q4EFoeWE5pn1Nm7gNsrzaO5avm4k,24724
191
- tico-0.1.0.dev250414.dist-info/LICENSE,sha256=dAq6L2M49W6wEPmkyVtYxE3Omm3oInXPd6qw5mQX6wY,12644
192
- tico-0.1.0.dev250414.dist-info/METADATA,sha256=NxIJgt_iKq8KOz07g5VLVRyCv7Q9UQXJYML-EWcQYoY,7353
193
- tico-0.1.0.dev250414.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
194
- tico-0.1.0.dev250414.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
195
- tico-0.1.0.dev250414.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
196
- tico-0.1.0.dev250414.dist-info/RECORD,,
190
+ tico/utils/validate_args_kwargs.py,sha256=krT68b5CfBI9rxBIOsgYSy0LfEJqLfKfRikkp8ep9oQ,24726
191
+ tico-0.1.0.dev250415.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
192
+ tico-0.1.0.dev250415.dist-info/METADATA,sha256=zqEFm2T9WSrEat4XqjQXesrRZTWeXN5BoaOZrdD6mhI,7353
193
+ tico-0.1.0.dev250415.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
194
+ tico-0.1.0.dev250415.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
195
+ tico-0.1.0.dev250415.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
196
+ tico-0.1.0.dev250415.dist-info/RECORD,,