tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0

tico/serialize/operators/op_constant_pad_nd.py

@@ -28,6 +28,42 @@ from tico.utils.errors import InvalidArgumentError
 from tico.utils.validate_args_kwargs import ConstantPadNdArgs
 
 
+def convert_to_circle_padding(pad, input_shape_len):
+    MAX_RANK = 4
+
+    if not (1 <= input_shape_len <= MAX_RANK):
+        raise InvalidArgumentError(
+            f"Input rank must be between 1 and {MAX_RANK}, got {input_shape_len}"
+        )
+
+    if len(pad) % 2 != 0 or len(pad) < 2 or len(pad) > 8:
+        raise InvalidArgumentError(
+            f"Pad length must be an even number between 2 and 8, got {len(pad)}"
+        )
+
+    if len(pad) == 2:
+        padding = [[pad[0], pad[1]]]
+    elif len(pad) == 4:
+        padding = [[pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 6:
+        padding = [[pad[4], pad[5]], [pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 8:
+        padding = [
+            [pad[6], pad[7]],
+            [pad[4], pad[5]],
+            [pad[2], pad[3]],
+            [pad[0], pad[1]],
+        ]
+    else:
+        assert False, "Cannot reach here"
+
+    # Fill [0, 0] padding for the rest of dimension
+    while len(padding) < input_shape_len:
+        padding.insert(0, [0, 0])
+
+    return padding
+
+
 @register_node_visitor
 class ConstantPadNdVisitor(NodeVisitor):
     target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
@@ -45,19 +81,13 @@ class ConstantPadNdVisitor(NodeVisitor):
         val = args.value
 
         if val != 0:
-            raise InvalidArgumentError("Only support 0 value padding.")
+            raise InvalidArgumentError(f"Only support 0 value padding. pad:{pad}")
 
         input_shape_len = len(extract_shape(input_))
-        padding_size = [[pad[2], pad[3]], [pad[0], pad[1]]]
-        if input_shape_len == 3:
-            padding_size = [[0, 0]] + padding_size
-        elif input_shape_len == 4:
-            padding_size = [[0, 0], [0, 0]] + padding_size
-        else:
-            raise InvalidArgumentError("Only support 3D/4D inputs.")
-
-        paddings = torch.tensor(padding_size, dtype=torch.int32)
-        inputs = [input_, paddings]
+
+        padding = convert_to_circle_padding(pad, input_shape_len)
+
+        inputs = [input_, torch.tensor(padding, dtype=torch.int32)]
         outputs = [node]
 
         op_index = get_op_index(
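
Note on the new helper: aten's constant_pad_nd lists [before, after] pairs starting from the last dimension, while the Circle pad operator expects one pair per dimension in regular dimension order, so convert_to_circle_padding reverses the pairs and zero-fills any leading dimensions. A quick sanity check (hypothetical standalone call, not part of the diff):

    # aten pad order: (W_left, W_right, H_top, H_bottom) for an NCHW input
    pad = [1, 2, 3, 4]
    convert_to_circle_padding(pad, input_shape_len=4)
    # -> [[0, 0], [0, 0], [3, 4], [1, 2]]  (pairs for N, C, H, W)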

tico/serialize/operators/op_conv2d.py

@@ -20,7 +20,11 @@ if TYPE_CHECKING:
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -111,9 +115,9 @@ class Conv2dVisitor(NodeVisitor):
 
         assert groups == 1, "Only support group 1 conv2d"
 
-        input_shape = list(extract_shape(input_))
-        output_shape = list(extract_shape(node))
-        weight_shape = list(extract_shape(weight))
+        input_shape = extract_shape(input_)
+        output_shape = extract_shape(node)
+        weight_shape = extract_shape(weight)
         assert len(input_shape) == 4, len(input_shape)
         assert len(output_shape) == 4, len(output_shape)
         assert len(weight_shape) == 4, len(weight_shape)
@@ -132,17 +136,21 @@ class Conv2dVisitor(NodeVisitor):
             ],
             dtype=torch.int32,
         )
-        pad_output_shape = [
+        pad_output_shape: List[int | torch.SymInt] = [
             input_shape[0],
             input_shape[1] + pad_h * 2,
             input_shape[2] + pad_w * 2,
             input_shape[3],
         ]
+        pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
+            pad_output_shape
+        )
         # create padded output tensor
         input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
         pad_output = self.graph.add_tensor_from_scratch(
             prefix=f"{node.name}_input_pad_output",
-            shape=pad_output_shape,
+            shape=pad_output_cshape,
+            shape_signature=pad_output_cshape_signature,
             dtype=extract_circle_dtype(input_),
             qparam=input_qparam,
             source_node=node,
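
to_circle_shape is imported from tico.serialize.circle_mapping (+76 -2 in the file list above) and splits a possibly-symbolic shape into a concrete shape plus an optional shape_signature. Its implementation is not shown in this diff; below is a minimal sketch of the assumed contract, following the TFLite/Circle convention where dynamic dimensions are marked -1 in the signature. The name to_circle_shape_sketch and the placeholder values are assumptions, not the library's code:

    from typing import List, Optional, Tuple

    import torch

    def to_circle_shape_sketch(
        shape: List[int | torch.SymInt],
    ) -> Tuple[List[int], Optional[List[int]]]:
        # Fully static shape: pass dims through; no signature needed.
        if all(isinstance(d, int) for d in shape):
            return list(shape), None
        # Symbolic dims get a placeholder 1 in the shape and -1 in the signature.
        cshape = [d if isinstance(d, int) else 1 for d in shape]
        signature = [d if isinstance(d, int) else -1 for d in shape]
        return cshape, signature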

tico/serialize/operators/op_copy.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING, Union
+from typing import Dict, List, Optional, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     import torch._ops
@@ -52,7 +52,15 @@ class CopyVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-    def check_to_do_broadcast(self, dst: List[int], src: List[int]) -> bool:
+    def check_to_do_broadcast(
+        self,
+        dst: List[int],
+        dst_sig: Optional[List[int]],
+        src: List[int],
+        src_sig: Optional[List[int]],
+    ) -> bool:
+        assert dst_sig is None
+        assert src_sig is None
         return dst != src
 
     def define_broadcast_to_node(
@@ -102,6 +110,12 @@ class CopyVisitor(NodeVisitor):
         # To connect 'dst' to Reshape node in the graph, 'dst' must be converted to Shape op.
         dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
         dst_shape: List[int] = dst_tensor.shape
+        dst_shape_signature: Optional[List[int]] = dst_tensor.shapeSignature
+
+        if dst_shape_signature is not None:
+            # TODO: support dynamic shape
+            raise NotYetSupportedError("Dynamic shape is not supported yet.")
+
         dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32)
 
         dst_shape_shape = [len(dst_shape)]
@@ -110,6 +124,7 @@
         shape_output = self.graph.add_tensor_from_scratch(
             prefix=f"{dst_name}_shape_output",
             shape=dst_shape_shape,
+            shape_signature=None,
             dtype=circle.TensorType.TensorType.INT32,
             source_node=node,
         )
@@ -119,9 +134,16 @@
 
         src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src)
         src_shape: List[int] = src_tensor.shape
+        src_shape_signature: Optional[List[int]] = src_tensor.shapeSignature
+
+        if src_shape_signature is not None:
+            # TODO: support dynamic shape
+            raise NotYetSupportedError("Dynamic shape is not supported yet.")
 
         # The src tensor must be broadcastable with the dst tensor.
-        do_broadcast = self.check_to_do_broadcast(dst_shape, src_shape)
+        do_broadcast = self.check_to_do_broadcast(
+            dst_shape, dst_shape_signature, src_shape, src_shape_signature
+        )
         if do_broadcast:
             # create braodcastTo output tensor
             src_name: str = src.name
@@ -131,6 +153,7 @@
             self.graph.add_tensor_from_scratch(
                 prefix=f"{src_name}_broadcast_to_output",
                 shape=dst_shape,
+                shape_signature=dst_shape_signature,
                 dtype=src_type,
                 source_node=node,
             )

tico/serialize/operators/op_cumsum.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch._ops
@@ -57,6 +57,7 @@ class CumsumVisitor(NodeVisitor):
         if input_dtype == torch.int32:
             input_tensor: circle.Tensor.TensorT = self.graph.get_tensor(input)
             input_shape: List[int] = input_tensor.shape
+            input_shape_signature: Optional[List[int]] = input_tensor.shapeSignature
             cast_op_index = get_op_index(
                 circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
             )
@@ -66,6 +67,7 @@
                 prefix=cast_name,
                 dtype=cast_dtype,
                 shape=input_shape,
+                shape_signature=input_shape_signature,
                 source_node=node,
             )
             cast_operator = create_builtin_operator(

tico/serialize/operators/op_depthwise_conv2d.py

@@ -20,7 +20,11 @@ if TYPE_CHECKING:
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -115,12 +119,13 @@ class DepthwiseConv2dVisitor(NodeVisitor):
         dilation = args.dilation
         groups = args.groups
 
-        input_shape = list(extract_shape(input_))  # OHWI
-        output_shape = list(extract_shape(node))  # OHWI
-        weight_shape = list(extract_shape(weight))  # 1HWO
+        input_shape = extract_shape(input_)  # OHWI
+        output_shape = extract_shape(node)  # OHWI
+        weight_shape = extract_shape(weight)  # 1HWO
         assert len(input_shape) == 4, len(input_shape)
         assert len(output_shape) == 4, len(output_shape)
-        assert len(weight_shape) == 4
+        assert len(weight_shape) == 4, len(weight_shape)
+
         assert weight_shape[0] == 1
         assert weight_shape[3] == output_shape[3]
         assert input_shape[3] == groups
@@ -145,17 +150,22 @@
             ],
             dtype=torch.int32,
         )
-        pad_output_shape = [
+        pad_output_shape: List[int | torch.SymInt] = [
             input_shape[0],
             input_shape[1] + pad_h * 2,
             input_shape[2] + pad_w * 2,
             input_shape[3],
         ]
+
+        pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
+            pad_output_shape
+        )
         # create padded output tensor
         input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
         pad_output = self.graph.add_tensor_from_scratch(
             prefix=f"{node.name}_input_pad_output",
-            shape=pad_output_shape,
+            shape=pad_output_cshape,
+            shape_signature=pad_output_cshape_signature,
             dtype=extract_circle_dtype(input_),
             qparam=input_qparam,
             source_node=node,

tico/serialize/operators/op_full_like.py

@@ -21,10 +21,8 @@ import torch
 from circle_schema import circle
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import to_circle_dtype
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
-from tico.serialize.operators.utils import create_builtin_operator, get_op_index
 from tico.utils.validate_args_kwargs import FullLikeArgs
 
 

tico/serialize/operators/op_index_select.py

@@ -49,7 +49,14 @@ class IndexSelectVisitor(NodeVisitor):
             self._op_codes,
         )
 
+        # TODO: Revise this to be simple
         dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
+        assert (
+            dim_i32.dim() == 0 or len(dim_i32) == 1
+        ), f"dim should be scalar: {dim_i32}"
+        dim_i32_item = dim_i32.item()
+        assert isinstance(dim_i32_item, int)
+
         inputs = [input, index]
         outputs = [node]
 
@@ -57,7 +64,7 @@
 
         operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.GatherOptions
         option = circle.GatherOptions.GatherOptionsT()
-        option.axis = dim_i32
+        option.axis = dim_i32_item
 
         operator.builtinOptions = option
 

tico/serialize/operators/op_instance_norm.py

@@ -73,12 +73,6 @@ class InstanceNormVisitor(NodeVisitor):
         eps = args.eps
 
         # Ignore training-related args
-        running_mean = args.running_mean
-        running_var = args.running_var
-        use_input_stats = args.use_input_stats
-        momentum = args.momentum
-        cudnn_enabled = args.cudnn_enabled
-
         input_shape = list(extract_shape(input))
         assert len(input_shape) == 4, len(input_shape)
 

tico/serialize/operators/op_le.py (new file)

@@ -0,0 +1,54 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import LeArgs
+
+
+@register_node_visitor
+class LeVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.aten.le.Scalar,
+        torch.ops.aten.le.Tensor,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+        args = LeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        other = args.other
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.LESS_EQUAL, self._op_codes
+        )
+
+        inputs = [input, other]
+        outputs = [node]
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        return operator
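
For reference, the two overloads targeted above are elementwise <= in PyTorch, and both lower to the single Circle LESS_EQUAL opcode. Plain-torch behavior (not part of the diff):

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    assert torch.ops.aten.le.Scalar(x, 2.0).tolist() == [True, True, False]
    assert torch.ops.aten.le.Tensor(
        x, torch.tensor([3.0, 1.0, 3.0])
    ).tolist() == [True, False, True]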

tico/serialize/operators/op_log1p.py

@@ -23,7 +23,7 @@ from circle_schema import circle
 from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.circle_mapping import (
     extract_circle_dtype,
-    extract_shape,
+    extract_circle_shape,
     extract_torch_dtype,
 )
 from tico.serialize.operators.hashable_opcode import OpCode
@@ -62,11 +62,12 @@ class Log1pVisitor(NodeVisitor):
         args = Log1pArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
 
-        input_shape = list(extract_shape(input))
+        input_shape, input_shape_signature = extract_circle_shape(input)
         dst_dtype_circle = extract_circle_dtype(input)
         add_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
             prefix=f"{input.name}_add",
             shape=input_shape,
+            shape_signature=input_shape_signature,
             dtype=dst_dtype_circle,
             source_node=node,
         )

tico/serialize/operators/op_max_pool2d_with_indices.py

@@ -22,7 +22,11 @@ import torch
 from circle_schema import circle
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import (
@@ -88,8 +92,15 @@ class MaxPool2DWithIndicesVisitor(NodeVisitor):
             ],
             dtype=torch.int32,
         )
-        input_shape = list(extract_shape(input))
+
+        input_shape = extract_shape(input)
         input_dtype: int = extract_circle_dtype(input)
+
+        input_qparam: Optional[QuantParam] = (
+            input.meta[QPARAM_KEY] if QPARAM_KEY in input.meta else None
+        )
+
+        # create padded input tensor
         padded_input_shape = [
             input_shape[0],
             input_shape[1],
@@ -98,17 +109,16 @@ class MaxPool2DWithIndicesVisitor(NodeVisitor):
         ]
         padded_input_shape[1] += padding[0] * 2
         padded_input_shape[2] += padding[1] * 2
-        input_qparam: Optional[QuantParam] = (
-            input.meta[QPARAM_KEY] if QPARAM_KEY in input.meta else None
-        )
-        # create padded input tensor
+        padded_cshape, padded_cshape_signature = to_circle_shape(padded_input_shape)
         padded_input_tensor = self.graph.add_tensor_from_scratch(
             prefix=f"{input.name}_pad_output",
-            shape=padded_input_shape,
+            shape=padded_cshape,
+            shape_signature=padded_cshape_signature,
             dtype=input_dtype,
             qparam=input_qparam,
             source_node=node,
         )
+
         if input_qparam is not None:
             padding_value = get_integer_dtype_min(input_qparam.dtype)
         else:
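
The pad value matters for max pooling: padded cells must never win the max. For quantized inputs the code above pads with the minimum representable value of the integer dtype (via get_integer_dtype_min). A quick illustration in plain torch:

    import torch

    # For an int8-quantized input, the safe pad value is the dtype minimum:
    assert torch.iinfo(torch.int8).min == -128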

tico/serialize/operators/op_mm.py

@@ -20,7 +20,7 @@ if TYPE_CHECKING:
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_graph import CircleSubgraph, is_const
+from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -28,9 +28,9 @@ from tico.utils.validate_args_kwargs import MatmulArgs
 
 
 @register_node_visitor
-class MatmulDefaultVisitor(NodeVisitor):
+class MatmulVisitor(NodeVisitor):
     """
-    Convert matmul to equavalent BatchMatMul or FullyConnected with Transpose.
+    Convert matmul to Circle BatchMatMul
     """
 
     target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -38,130 +38,7 @@ class MatmulDefaultVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-    # NOTE: Matmul is equivalent to Batch MatMul (batch=1)
-    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_bmm_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-            )
-            option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-            option.adjointLhs, option.adjointRhs = False, False
-            option.asymmetricQuantizeInputs = False
-            operator.builtinOptions = option
-
-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
-        set_bmm_option(operator)
-
-        return operator
-
-    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_transpose_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.TransposeOptions
-            )
-            option = circle.TransposeOptions.TransposeOptionsT()
-            operator.builtinOptions = option
-
-        transpose_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
-        )
-        operator = create_builtin_operator(
-            self.graph, transpose_op_index, inputs, outputs
-        )
-        set_transpose_option(operator)
-        return operator
-
-    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_fc_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
-            )
-            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
-
-            option.fusedActivationFunction = (
-                circle.ActivationFunctionType.ActivationFunctionType.NONE
-            )
-            option.weightsFormat = (
-                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
-            )
-            option.keepNumDims = False
-            option.asymmetricQuantizeInputs = False
-            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
-
-            operator.builtinOptions = option
-
-        fc_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
-        set_fc_option(operator)
-        return operator
-
-    """
-    Define FullyConnnected with Tranpose operator.
-    Note that those sets of operators are equivalent.
-      (1) Matmul
-            matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')
-
-      (2) Transpose + FullyConneccted
-            transpose( rhs[K, W'] ) -> trs_output[W', K]
-            fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')
-    """
-
-    def define_fc_with_transpose(
-        self, node, inputs, outputs
-    ) -> circle.Operator.OperatorT:
-        lhs, rhs = inputs
-
-        # get transpose shape
-        rhs_tid: int = self.graph.get_tid_registered(rhs)
-        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
-        rhs_name: str = rhs.name
-        rhs_type: int = rhs_tensor.type
-        rhs_shape: List[int] = rhs_tensor.shape
-        assert len(rhs_shape) == 2, len(rhs_shape)
-        rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]
-
-        # create transpose output tensor
-        trs_output = self.graph.add_tensor_from_scratch(
-            prefix=f"{rhs_name}_transposed_output",
-            shape=rhs_shape_transpose,
-            dtype=rhs_type,
-            source_node=node,
-        )
-        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
-        trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
-        self.graph.add_operator(trs_operator)
-
-        # define fc node
-        fc_input = lhs
-        fc_weight = trs_output
-        fc_shape = [fc_weight.shape[0]]
-        fc_bias = self.graph.add_const_tensor(
-            data=[0.0] * fc_shape[0], source_node=node
-        )
-
-        operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)
-
-        return operator
-
-    def define_node(
-        self, node: torch.fx.Node, prior_latency=True
-    ) -> circle.Operator.OperatorT:
-        """
-        NOTE: Possibility of accuracy-latency trade-off
-        From ONE compiler's perspective:
-        - BMM uses per-tensor quantization for both rhs and lhs.
-        - FC uses per-channel quantization for weight and per-tensor for input.
-        Thus, FC is better in terms of accuracy.
-        FC necessarily involves an additional transpose operation to be identical with mm.
-        If transposed operand is const, it can be optimized by constant folding.
-        Thus, convert FC only if tranpose can be folded.
-        TODO set prior_latency outside
-        """
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
         args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         other = args.other
@@ -169,9 +46,16 @@ class MatmulDefaultVisitor(NodeVisitor):
         inputs = [input, other]
         outputs = [node]
 
-        if not is_const(other) and prior_latency:
-            operator = self.define_bmm_node(inputs, outputs)
-        else:
-            operator = self.define_fc_with_transpose(node, inputs, outputs)
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
+        )
+        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
+        option.adjointLhs, option.adjointRhs = False, False
+        option.asymmetricQuantizeInputs = False
+        operator.builtinOptions = option
 
         return operator
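
The FullyConnected-with-Transpose path removed here is not simply dropped: the file list above adds a new tico/passes/convert_matmul_to_linear.py pass (+312 lines), which presumably takes over that lowering at the graph level, leaving the serializer with only the BatchMatMul form. The equivalence noted in the removed code, that a 2-D matmul is a batch matmul with batch size 1, is easy to check in plain torch:

    import torch

    lhs, rhs = torch.randn(3, 4), torch.randn(4, 5)
    assert torch.allclose(torch.mm(lhs, rhs), torch.bmm(lhs[None], rhs[None])[0])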

tico/serialize/operators/op_mul.py

@@ -66,10 +66,7 @@ class MulTensorVisitor(BaseMulVisitor):
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-        args = MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
-        input = args.input
-        other = args.other
-
+        _ = MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         operator = super().define_node(
             node,
         )
@@ -88,10 +85,7 @@ class MulScalarVisitor(BaseMulVisitor):
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-        args = MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
-        input = args.input
-        other = args.other
-
+        _ = MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         operator = super().define_node(
             node,
         )

tico/serialize/operators/op_pow.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch._ops
@@ -36,6 +36,7 @@ class BasePowVisitor(NodeVisitor):
         assert isinstance(node, torch.fx.Node), type(node)
         node_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)
         node_shape: List[int] = node_tensor.shape
+        node_shape_signature: Optional[List[int]] = node_tensor.shapeSignature
         op_index = get_op_index(
             circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
         )
@@ -45,6 +46,7 @@
             prefix=cast_name,
             dtype=cast_dtype,
             shape=node_shape,
+            shape_signature=node_shape_signature,
             source_node=node,
         )
         cast_operator = create_builtin_operator(