tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/serialize/operators/op_repeat.py CHANGED
@@ -21,7 +21,11 @@ import torch
 from circle_schema import circle
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -70,11 +74,16 @@ class RepeatVisitor(NodeVisitor):
             if r > 1:
                 # Except last created concat, a tensor should be created.
                 if repeat_dim_cnt > 1:
-                    repeated_shape = list(tensor_shape)
+                    repeated_shape: List[int | torch.SymInt] = list(tensor_shape)
                     repeated_shape[idx] = repeated_shape[idx] * r
+
+                    repeated_cshape, repeated_cshape_signature = to_circle_shape(
+                        repeated_shape
+                    )
                     concat_output = self.graph.add_tensor_from_scratch(
                         prefix=f"{node.name}_concat_{idx}",
-                        shape=repeated_shape,
+                        shape=repeated_cshape,
+                        shape_signature=repeated_cshape_signature,
                         dtype=tensor_dtype,
                         source_node=node,
                     )
tico/serialize/operators/op_reshape.py CHANGED
@@ -66,7 +66,7 @@ class ReshapeVisitor(NodeVisitor):
             circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
         )
         option = circle.ReshapeOptions.ReshapeOptionsT()
-        option.newShape = size_i32
+        option.newShape = size_i32.tolist()
 
         operator.builtinOptions = option
 
tico/serialize/operators/op_rmsnorm.py ADDED
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import CircleRMSNormArgs
+
+
+@register_node_visitor
+class RMSNormVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.rms_norm.default,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = CircleRMSNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        weight = args.weight
+        eps = args.eps
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.RMS_NORM, self._op_codes
+        )
+
+        inputs = [input, weight]
+        outputs = [node]
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.RmsNormOptions
+        )
+        option = circle.RmsNormOptions.RmsNormOptionsT()
+        option.epsilon = eps
+
+        operator.builtinOptions = option
+
+        return operator
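
This visitor serializes `torch.ops.circle_custom.rms_norm` into a Circle RMS_NORM operator whose only option is `epsilon`. For orientation, a minimal reference of the computation, assuming the op follows the usual RMSNorm definition (as in HuggingFace's LlamaRMSNorm, for which this release also adds an adapter in tico/serialize/operators/adapters/llama_rmsnorm.py):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

# e.g. rms_norm_reference(torch.randn(2, 8), torch.ones(8), 1e-6)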
tico/serialize/operators/op_softmax.py CHANGED
@@ -24,25 +24,18 @@ from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
 from tico.utils.errors import NotYetSupportedError
-from tico.utils.utils import HAS_TORCH_OVER_25
 from tico.utils.validate_args_kwargs import SafeSoftmaxArgs, SoftmaxArgs
 
 
 @register_node_visitor
 class SoftMaxVisitor(NodeVisitor):
-    target: List[torch._ops.OpOverload] = (
-        [
-            torch.ops.aten._softmax.default,
-            # NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
-            # In order for optimization during inference, it can be replaced to softmax.
-            # ref: https://github.com/pytorch/pytorch/pull/133882
-            torch.ops.aten._safe_softmax.default,
-        ]
-        if HAS_TORCH_OVER_25
-        else [
-            torch.ops.aten._softmax.default,
-        ]
-    )
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.aten._softmax.default,
+        # NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
+        # In order for optimization during inference, it can be replaced to softmax.
+        # ref: https://github.com/pytorch/pytorch/pull/133882
+        torch.ops.aten._safe_softmax.default,
+    ]
 
     def __init__(self, op_codes: Dict[OpCode, int], graph):
         super().__init__(op_codes, graph)
tico/serialize/operators/op_split_with_sizes.py CHANGED
@@ -58,12 +58,14 @@ class SplitWithSizesVisitor(NodeVisitor):
         inputs = [input, split_sizes_i32, axis_i32]
 
         """
-        `split_with_sizes` has multiple output tensors and they are represented as `getitem`.
-        Therefore, unlike other ops, node itself doesn't become a circle tensor. Instead, each `getitem` will be
+        `split_with_sizes` has multiple output tensors along with `getitem`.
+        Unlike other ops, node itself doesn't become a circle tensor. Instead, each `getitem` will be
         a circle tensor.
-        Further, torch module having `split_with_sizes` may somtimes return selected outputs. At that time, `getitem`
-        nodes are generated only for the ouptut selected. Since one-compiler assumes that `CircleSplitV` always has
-        all the outputs, let's add unused output tensors to compensate this restriction.
+
+        torch module having `split_with_sizes` may return selected outputs by using `getitem`.
+        However, one-compiler assumes that `CircleSplitV` always have all outputs.
+
+        So, let's add unused output tensors to compensate this restriction.
         """
         outputs: List[Union[circle.Tensor.TensorT, torch.fx.node.Node]] = []
         sorted_users = sorted(node.users.keys(), key=lambda x: x.args[1])  # type: ignore[arg-type, return-value]
@@ -80,11 +82,17 @@ class SplitWithSizesVisitor(NodeVisitor):
                 fake_tensor = node_val[idx]
                 assert isinstance(fake_tensor, FakeTensor)
                 shape = list(fake_tensor.size())
+
+                if any(isinstance(s, torch.SymInt) for s in shape):
+                    # TODO: support dynamic shape
+                    raise NotImplementedError("Dynamic shape is not supported yet.")
+
                 dtype = to_circle_dtype(fake_tensor.dtype)
                 tensor = self.graph.add_tensor_from_scratch(
-                    f"{node.name}_unused_{idx}",
-                    shape,
-                    dtype,
+                    prefix=f"{node.name}_unused_{idx}",
+                    shape=shape,
+                    shape_signature=None,  # TODO: support dynamic shape
+                    dtype=dtype,
                     source_node=node,
                 )
                 outputs.append(tensor)
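
The docstring above explains why unused outputs must still be materialized. A small sketch of how that situation arises during export (module and names here are illustrative):

import torch

class TakeSecond(torch.nn.Module):
    def forward(self, x):
        a, b = torch.split(x, [2, 3], dim=-1)
        return b  # only one split output is consumed

ep = torch.export.export(TakeSecond().eval(), (torch.randn(1, 5),))
# The graph contains split_with_sizes followed by a single getitem node;
# the serializer must still emit a tensor for the unselected output.
print(ep.graph)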
tico/serialize/operators/op_transpose_conv.py CHANGED
@@ -23,7 +23,8 @@ from circle_schema import circle
 from tico.serialize.circle_mapping import (
     circle_legalize_dtype_to,
     extract_circle_dtype,
-    extract_shape,
+    extract_circle_shape,
+    to_circle_shape,
 )
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
@@ -76,15 +77,13 @@ class TransposeConvVisitor(NodeVisitor):
         bias = args.bias
         stride = args.stride
         padding = args.padding
-        output_padding = args.output_padding
         groups = args.groups
-        dilation = args.dilation
 
         assert groups == 1, "Only support group 1"
 
-        input_shape = list(extract_shape(input_))
-        output_shape = list(extract_shape(node))
-        weight_shape = list(extract_shape(weight))
+        input_shape, input_shape_signature = extract_circle_shape(input_)
+        output_shape, _ = extract_circle_shape(node)
+        weight_shape, _ = extract_circle_shape(weight)
         assert len(input_shape) == 4, len(input_shape)
         assert len(output_shape) == 4, len(output_shape)
         assert len(weight_shape) == 4, len(weight_shape)
@@ -103,17 +102,21 @@ class TransposeConvVisitor(NodeVisitor):
             ],
             dtype=torch.int32,
         )
-        pad_output_shape = [
+        pad_output_shape: List[int | torch.SymInt] = [
            input_shape[0],
            input_shape[1] + pad_h * 2,
            input_shape[2] + pad_w * 2,
            input_shape[3],
        ]
+        pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
+            pad_output_shape
+        )
         # create padded output tensor
         input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
         pad_output = self.graph.add_tensor_from_scratch(
             prefix=f"{node.name}_input_pad_output",
-            shape=pad_output_shape,
+            shape=pad_output_cshape,
+            shape_signature=pad_output_cshape_signature,
             dtype=extract_circle_dtype(input_),
             qparam=input_qparam,
             source_node=node,
tico/serialize/operators/op_view.py CHANGED
@@ -56,6 +56,7 @@ class ViewVisitor(NodeVisitor):
         if isinstance(size, int):
             raise Exception("scalar size conversion is not supported yet.")
 
+        # TODO: support dynamic shape
         size_i32 = circle_legalize_dtype_to(size, dtype=torch.int32)
         inputs = [input, size_i32]
         outputs = [node]
@@ -67,7 +68,7 @@ class ViewVisitor(NodeVisitor):
             circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
         )
         option = circle.ReshapeOptions.ReshapeOptionsT()
-        option.newShape = size_i32
+        option.newShape = size_i32.tolist()
 
         operator.builtinOptions = option
 
tico/serialize/quant_param.py CHANGED
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
 """
 This is a key for torch.fx.Node's meta dict to save QuantParam
 
@@ -19,11 +24,6 @@ QuantParam can be retrieved as node.meta[QPARAM_KEY]
 """
 QPARAM_KEY = "_quantization_parameters_"
 
-from dataclasses import dataclass
-from typing import List, Optional
-
-import torch
-
 
 @dataclass
 class QuantParam:
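
As the docstring notes, QuantParam rides along on node metadata under QPARAM_KEY. A minimal sketch of the lookup pattern (the placeholder node here is illustrative):

from typing import Optional

import torch

from tico.serialize.quant_param import QPARAM_KEY, QuantParam

g = torch.fx.Graph()
n = g.placeholder("x")

# Quantization parameters, when present, live in the node's meta dict.
qparam: Optional[QuantParam] = n.meta.get(QPARAM_KEY)
assert qparam is None  # nothing has been attached to this fresh node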
tico/utils/convert.py CHANGED
@@ -20,25 +20,14 @@ import torch
 from torch.export import export, ExportedProgram
 
 from tico.config import CompileConfigBase, get_default_config
-from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
-from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
-    InsertQuantizeOnDtypeMismatch,
-)
-from tico.experimental.quantization.passes.propagate_qparam_backward import (
-    PropagateQParamBackward,
-)
-from tico.experimental.quantization.passes.propagate_qparam_forward import (
-    PropagateQParamForward,
-)
-from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
-from tico.experimental.quantization.passes.remove_weight_dequant_op import (
-    RemoveWeightDequantOp,
-)
 from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
+from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
 from tico.passes.convert_to_relu6 import ConvertToReLU6
 from tico.passes.decompose_addmm import DecomposeAddmm
@@ -71,6 +60,14 @@ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
 from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
 from tico.passes.restore_linear import RestoreLinear
 from tico.passes.segment_index_select import SegmentIndexSelectConst
+from tico.quantization.passes.fold_quant_ops import FoldQuantOps
+from tico.quantization.passes.insert_quantize_on_dtype_mismatch import (
+    InsertQuantizeOnDtypeMismatch,
+)
+from tico.quantization.passes.propagate_qparam_backward import PropagateQParamBackward
+from tico.quantization.passes.propagate_qparam_forward import PropagateQParamForward
+from tico.quantization.passes.quantize_bias import QuantizeBias
+from tico.quantization.passes.remove_weight_dequant_op import RemoveWeightDequantOp
 from tico.serialize.circle_serializer import build_circle
 from tico.serialize.operators.node_visitor import get_support_targets
 from tico.utils import logging
@@ -105,6 +102,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten._safe_softmax.default,
         torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
         torch.ops.aten.linear.default,
+        torch.ops.aten.upsample_nearest2d.vec,
     )
     ep = ep.run_decompositions(_preserve_ops=_preserve_ops)
 
@@ -123,6 +121,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
         torch.ops.aten.prelu.default,
         torch.ops.aten.linear.default,
+        torch.ops.aten.upsample_nearest2d.vec,
     )
     for op in _preserve_ops:
         if op in _decomp_table:
@@ -137,6 +136,8 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.__version__.startswith("2.6")
         or torch.__version__.startswith("2.7")
         or torch.__version__.startswith("2.8")
+        or torch.__version__.startswith("2.9")
+        or torch.__version__.startswith("2.10")
     ):
         return run_decompositions(exported_program)
     else:
@@ -153,7 +154,7 @@ def check_unsupported_target(exported_program: ExportedProgram):
     for n in exported_program.graph.nodes:
         if n.op != "call_function":
             continue
-        if not n.target in supported_target:
+        if n.target not in supported_target:
            unsupported.append(n)
 
    if unsupported:
@@ -245,12 +246,21 @@ def convert_exported_module_to_circle(
            ConstPropPass(),
            SegmentIndexSelectConst(),
            LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+           ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
+           ConvertMatmulToLinear(
+               enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+               enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+               enable_single_batch_lhs_const_bmm=config.get(
+                   "convert_single_batch_lhs_const_bmm_to_fc"
+               ),
+           ),
            LowerToResizeNearestNeighbor(),
            LegalizePreDefinedLayoutOperators(),
            LowerPow2ToMul(),
            ConvertConv1dToConv2d(),
            *LowerToSlicePasses(),
            FuseLeadingUnsqueezeReshape(),
+           CastClampMixedTypeArgs(),
        ]
    )
    circle_legalize.run(exported_program)
@@ -282,7 +292,7 @@ def convert_exported_module_to_circle(
 
     check_unsupported_target(exported_program)
     check_training_ops(exported_program)
-    circle_program = build_circle(exported_program)
+    circle_program = build_circle(exported_program, config)
 
     return circle_program
 
@@ -291,6 +301,7 @@ def convert(
     mod: torch.nn.Module,
     args: Tuple[Any, ...],
     kwargs: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[dict] = None,
     strict: bool = True,
     config: CompileConfigBase = get_default_config(),
 ) -> CircleModel:
@@ -301,7 +312,9 @@ def convert(
     )
 
     with torch.no_grad():
-        exported_program = export(mod, args, kwargs, strict=strict)
+        exported_program = export(
+            mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
+        )
 
     circle_binary = convert_exported_module_to_circle(exported_program, config=config)
 
tico/utils/dtype.py ADDED
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+
+from circle_schema import circle
+
+NUMPY_TO_TORCH_DTYPE_DICT = {
+    np.dtype("float32"): torch.float32,
+    np.dtype("float64"): torch.float64,
+    np.dtype("float16"): torch.float16,
+    np.dtype("complex64"): torch.complex64,
+    np.dtype("complex128"): torch.complex128,
+    np.dtype("int64"): torch.int64,
+    np.dtype("int32"): torch.int32,
+    np.dtype("int16"): torch.int16,
+    np.dtype("int8"): torch.int8,
+    np.dtype("uint8"): torch.uint8,
+    np.dtype("bool"): torch.bool,
+}
+
+CIRCLE_TO_TORCH_DTYPE_DICT = {
+    circle.TensorType.TensorType.FLOAT32: torch.float32,
+    circle.TensorType.TensorType.UINT8: torch.uint8,
+    circle.TensorType.TensorType.INT8: torch.int8,
+    circle.TensorType.TensorType.INT16: torch.int16,
+    circle.TensorType.TensorType.INT32: torch.int32,
+    circle.TensorType.TensorType.INT64: torch.int64,
+    circle.TensorType.TensorType.BOOL: torch.bool,
+}
+
+
+def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype:
+    return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype]
+
+
+def circle_dtype_to_torch_dtype(circle_dtype: int) -> torch.dtype:
+    assert isinstance(circle_dtype, int)
+    if circle_dtype not in CIRCLE_TO_TORCH_DTYPE_DICT:
+        raise RuntimeError(f"Unsupported dtype {circle_dtype}")
+
+    torch_dtype = CIRCLE_TO_TORCH_DTYPE_DICT[circle_dtype]
+    assert torch_dtype is not None
+    return torch_dtype
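
Usage is straightforward; a quick sketch exercising both new helpers:

import numpy as np
import torch
from circle_schema import circle

from tico.utils.dtype import circle_dtype_to_torch_dtype, numpy_dtype_to_torch_dtype

assert numpy_dtype_to_torch_dtype(np.dtype("float32")) is torch.float32
assert circle_dtype_to_torch_dtype(circle.TensorType.TensorType.INT8) is torch.int8
# Unmapped circle dtypes raise RuntimeError instead of silently returning None.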
tico/utils/graph.py CHANGED
@@ -24,7 +24,7 @@ import torch
 from torch.export import ExportedProgram
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 
-from tico.utils.utils import get_fake_mode, set_new_meta_val
+from tico.utils.utils import get_fake_mode
 
 
 def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
tico/utils/model.py CHANGED
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+from pathlib import Path
 from typing import Any
 
 from tico.interpreter import infer
@@ -32,6 +33,6 @@ class CircleModel:
         buf = bytes(f.read())
         return CircleModel(buf)
 
-    def save(self, circle_path: str) -> None:
+    def save(self, circle_path: str | Path) -> None:
         with open(circle_path, "wb") as f:
             f.write(self.circle_binary)
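
`save` now accepts `pathlib.Path` alongside plain strings. A small sketch, with stand-in bytes and file names:

from pathlib import Path

from tico.utils.model import CircleModel

model = CircleModel(b"\x00")  # normally the bytes come from tico.convert(...)
out_dir = Path("artifacts")
out_dir.mkdir(exist_ok=True)
model.save(out_dir / "model.circle")  # str paths still work too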
tico/utils/padding.py CHANGED
@@ -39,8 +39,8 @@ class ConvPaddingInfo(NamedTuple):
 
 def identify_padding(
     padding: PaddingValue,
-    input_shape: Sequence[int],
-    output_shape: Sequence[int],
+    input_shape: Sequence[int | torch.SymInt] | torch.Size,
+    output_shape: Sequence[int | torch.SymInt] | torch.Size,
     stride: Sequence[int],
 ) -> ConvPaddingInfo:
     """
tico/utils/pytree_utils.py ADDED
@@ -0,0 +1,134 @@
+import threading
+
+import torch
+from packaging.version import Version
+
+from tico.utils import logging
+from tico.utils.installed_packages import is_transformers_installed
+
+__all__ = ["register_dynamic_cache"]
+
+
+def register_dynamic_cache():
+    PyTreeRegistryHelper().register_dynamic_cache()
+
+
+class PyTreeRegistryHelper:
+    """
+    Thread-safe singleton helper class for registering custom PyTree nodes.
+
+    This class provides functionality to register DynamicCache as a PyTree node
+    for torch.export compatibility. This registration is only needed for
+    transformers versions below 4.50.0.
+
+    Thread Safety:
+        - Uses a class-level threading.Lock() to ensure thread-safe singleton instantiation
+        - Uses the same lock to protect the registration process from concurrent calls
+    """
+
+    _instance = None  # Class variable to hold the singleton instance
+    _has_called = False  # Flag to track if registration has been performed
+    _lock = threading.Lock()  # Class-level lock for thread-safe operations
+
+    def __init__(self):
+        """Private constructor to prevent direct instantiation"""
+        pass
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Thread-safe singleton instance creation using double-checked locking pattern.
+
+        Returns:
+            PyTreeRegistryHelper: The singleton instance of this class
+        """
+        if not cls._instance:
+            with cls._lock:  # Acquire lock for thread-safe instantiation
+                if not cls._instance:  # Double-check after acquiring lock
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def register_dynamic_cache(self):
+        """
+        Registers DynamicCache as a PyTree node for torch.export compatibility.
+
+        This method is thread-safe and idempotent - it will only perform the
+        registration once, even if called multiple times from different threads.
+
+        Note:
+            This registration is only needed for transformers versions below 4.50.0.
+
+        Raises:
+            ImportError: If transformers package is not installed
+        """
+        with self._lock:  # Acquire lock for thread-safe registration
+            if self.__class__._has_called:
+                logger = logging.getLogger(__name__)
+                logger.debug("register_dynamic_cache already called, skipping")
+                return
+
+            self.__class__._has_called = True
+            logger = logging.getLogger(__name__)
+            logger.info("Registering DynamicCache PyTree node")
+
+            if not is_transformers_installed:  # type: ignore[truthy-function]
+                raise ImportError("transformers package is not installed")
+
+            import transformers
+
+            HAS_TRANSFORMERS_LESS_4_50_0 = Version(transformers.__version__) < Version(
+                "4.50.0"
+            )
+            if not HAS_TRANSFORMERS_LESS_4_50_0:
+                return
+
+            from transformers.cache_utils import DynamicCache
+
+            def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
+                if not isinstance(dynamic_cache, DynamicCache):
+                    raise RuntimeError(
+                        "This pytree flattening function should only be applied to DynamicCache"
+                    )
+                HAS_TORCH_2_6_0 = Version(torch.__version__) >= Version("2.6.0")
+                if not HAS_TORCH_2_6_0:
+                    logger = logging.getLogger(__name__)
+                    logger.warning_once(
+                        "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
+                    )
+                dictionary = {
+                    "key_cache": getattr(dynamic_cache, "key_cache"),
+                    "value_cache": getattr(dynamic_cache, "value_cache"),
+                }
+                return torch.utils._pytree._dict_flatten(dictionary)
+
+            def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
+                dictionary = {
+                    "key_cache": getattr(dynamic_cache, "key_cache"),
+                    "value_cache": getattr(dynamic_cache, "value_cache"),
+                }
+                return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+            def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
+                dictionary = torch.utils._pytree._dict_unflatten(values, context)
+                cache = DynamicCache()
+                for k, v in dictionary.items():
+                    setattr(cache, k, v)
+                return cache
+
+            def _flatten_dynamic_cache_for_fx(cache, spec):
+                dictionary = {
+                    "key_cache": getattr(cache, "key_cache"),
+                    "value_cache": getattr(cache, "value_cache"),
+                }
+                return torch.fx._pytree._dict_flatten_spec(dictionary, spec)
+
+            torch.utils._pytree.register_pytree_node(
+                DynamicCache,
+                _flatten_dynamic_cache,
+                _unflatten_dynamic_cache,
+                serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+                flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
+            )
+            # TODO: This won't be needed in torch 2.7+.
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, _flatten_dynamic_cache_for_fx
+            )
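
A usage sketch: call the helper once before flattening or exporting anything that carries a DynamicCache. On transformers >= 4.50.0 the call is a no-op, since the docstring notes the registration is only needed below that version:

import torch
from transformers.cache_utils import DynamicCache

from tico.utils.pytree_utils import register_dynamic_cache

register_dynamic_cache()  # idempotent; safe to call from multiple threads

cache = DynamicCache()
leaves, spec = torch.utils._pytree.tree_flatten(cache)
restored = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(restored, DynamicCache)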