tico 0.1.0.dev251106__py3-none-any.whl → 0.2.0.dev260122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tico/__init__.py +2 -2
  2. tico/_version.py +1 -0
  3. tico/passes/convert_conv3d_to_conv2d.py +435 -0
  4. tico/passes/convert_sym_size_to_circle_shape.py +99 -0
  5. tico/passes/decompose_batch_norm.py +9 -5
  6. tico/passes/lower_copy.py +95 -0
  7. tico/passes/ops.py +4 -0
  8. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +251 -0
  9. tico/quantization/algorithm/fpi_gptq/quantizer.py +180 -0
  10. tico/quantization/algorithm/gptq/gptq.py +231 -11
  11. tico/quantization/algorithm/gptq/quantizer.py +18 -6
  12. tico/quantization/config/{pt2e.py → fpi_gptq.py} +11 -4
  13. tico/quantization/config/gptq.py +27 -4
  14. tico/quantization/public_interface.py +0 -10
  15. tico/quantization/wrapq/quantizer.py +2 -0
  16. tico/quantization/wrapq/wrappers/quant_elementwise.py +51 -11
  17. tico/serialize/operators/adapters/onert/llama_attention.py +51 -0
  18. tico/serialize/operators/op_attention.py +58 -0
  19. tico/serialize/operators/op_circle_shape.py +64 -0
  20. tico/serialize/operators/op_dequantize_per_channel.py +1 -0
  21. tico/serialize/operators/op_dequantize_per_tensor.py +1 -0
  22. tico/serialize/operators/op_transpose_conv.py +66 -50
  23. tico/utils/convert.py +16 -1
  24. tico/utils/padding.py +13 -5
  25. tico/utils/record_input.py +2 -2
  26. tico/utils/register_custom_op.py +63 -0
  27. tico/utils/validate_args_kwargs.py +49 -4
  28. tico-0.2.0.dev260122.dist-info/METADATA +631 -0
  29. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/RECORD +35 -46
  30. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/WHEEL +1 -1
  31. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/entry_points.txt +0 -1
  32. tico/quantization/algorithm/pt2e/annotation/annotator.py +0 -208
  33. tico/quantization/algorithm/pt2e/annotation/config.py +0 -26
  34. tico/quantization/algorithm/pt2e/annotation/op/__init__.py +0 -21
  35. tico/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +0 -63
  36. tico/quantization/algorithm/pt2e/annotation/op/add.py +0 -55
  37. tico/quantization/algorithm/pt2e/annotation/op/conv2d.py +0 -90
  38. tico/quantization/algorithm/pt2e/annotation/op/div.py +0 -55
  39. tico/quantization/algorithm/pt2e/annotation/op/linear.py +0 -92
  40. tico/quantization/algorithm/pt2e/annotation/op/mean.py +0 -51
  41. tico/quantization/algorithm/pt2e/annotation/op/mul.py +0 -55
  42. tico/quantization/algorithm/pt2e/annotation/op/relu6.py +0 -51
  43. tico/quantization/algorithm/pt2e/annotation/op/rsqrt.py +0 -51
  44. tico/quantization/algorithm/pt2e/annotation/op/sub.py +0 -55
  45. tico/quantization/algorithm/pt2e/annotation/spec.py +0 -45
  46. tico/quantization/algorithm/pt2e/annotation/utils.py +0 -88
  47. tico/quantization/algorithm/pt2e/quantizer.py +0 -81
  48. tico/quantization/algorithm/pt2e/transformation/__init__.py +0 -1
  49. tico/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -58
  50. tico/quantization/algorithm/pt2e/utils.py +0 -135
  51. tico/serialize/operators/op_copy.py +0 -187
  52. tico-0.1.0.dev251106.dist-info/METADATA +0 -392
  53. /tico/quantization/algorithm/{pt2e → fpi_gptq}/__init__.py +0 -0
  54. /tico/{quantization/algorithm/pt2e/annotation → serialize/operators/adapters/onert}/__init__.py +0 -0
  55. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info/licenses}/LICENSE +0 -0
  56. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/top_level.txt +0 -0
tico/serialize/operators/op_circle_shape.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+
28
+
29
+ @register_node_visitor
30
+ class CircleShapeVisitor(NodeVisitor):
31
+ """
32
+ Visitor for circle_custom::shape operator.
33
+
34
+ This operator extracts the shape of a tensor.
35
+ It's implemented using Circle's SHAPE operator.
36
+ """
37
+
38
+ target: List[torch._ops.OpOverload] = [
39
+ torch.ops.circle_custom.shape,
40
+ ]
41
+
42
+ def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
43
+ super().__init__(op_codes, graph)
44
+
45
+ def define_node(
46
+ self,
47
+ node: torch.fx.Node,
48
+ ) -> circle.Operator.OperatorT:
49
+ # Args: (input)
50
+ input_node = node.args[0]
51
+
52
+ # SHAPE operator to get the full shape of input
53
+ op_index = get_op_index(
54
+ circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes
55
+ )
56
+
57
+ inputs = [input_node]
58
+ outputs = [node]
59
+ operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
60
+ operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions
61
+ operator.builtinOptions = circle.ShapeOptions.ShapeOptionsT()
62
+ operator.builtinOptions.outType = circle.TensorType.TensorType.INT32
63
+
64
+ return operator
tico/serialize/operators/op_dequantize_per_channel.py CHANGED
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
19
19
  import torch.fx
20
20
  import numpy as np
21
21
  import torch
22
+ import torch.ao.quantization.fx._decomposed # register `dequantize_per_channel`
22
23
  from circle_schema import circle
23
24
 
24
25
  from tico.serialize.circle_graph import CircleSubgraph
tico/serialize/operators/op_dequantize_per_tensor.py CHANGED
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
18
18
  import torch._ops
19
19
  import torch.fx
20
20
  import torch
21
+ import torch.ao.quantization.fx._decomposed # register `dequantize_per_tensor`
21
22
  from circle_schema import circle
22
23
 
23
24
  from tico.serialize.circle_graph import CircleSubgraph
tico/serialize/operators/op_transpose_conv.py CHANGED
@@ -66,6 +66,29 @@ class TransposeConvVisitor(NodeVisitor):
66
66
  set_transpose_conv_option(operator, stride)
67
67
  return operator
68
68
 
69
+ def define_slice_node(
70
+ self,
71
+ src_tensor,
72
+ begin_vals: List[int],
73
+ size_vals: List[int],
74
+ dst_tensor,
75
+ ) -> circle.Operator.OperatorT:
76
+ slice_op_index = get_op_index(
77
+ circle.BuiltinOperator.BuiltinOperator.SLICE, self._op_codes
78
+ )
79
+
80
+ # Begin / Size as int32 const tensors
81
+ begin_arr = circle_legalize_dtype_to(begin_vals, dtype=torch.int32)
82
+ size_arr = circle_legalize_dtype_to(size_vals, dtype=torch.int32)
83
+
84
+ operator = create_builtin_operator(
85
+ self.graph, slice_op_index, [src_tensor, begin_arr, size_arr], [dst_tensor]
86
+ )
87
+ operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.SliceOptions
88
+ option = circle.SliceOptions.SliceOptionsT()
89
+ operator.builtinOptions = option
90
+ return operator
91
+
69
92
  def __init__(self, op_codes: Dict[OpCode, int], graph):
70
93
  super().__init__(op_codes, graph)
71
94
 
@@ -88,65 +111,58 @@ class TransposeConvVisitor(NodeVisitor):
88
111
  assert len(output_shape) == 4, len(output_shape)
89
112
  assert len(weight_shape) == 4, len(weight_shape)
90
113
 
91
- pad_decision = identify_padding(padding, input_shape, output_shape, stride)
114
+ pad_decision = identify_padding(
115
+ padding, input_shape, output_shape, stride, is_transpose=True
116
+ )
92
117
 
93
118
  conv_input: torch.fx.Node | circle.Tensor.TensorT = input_
94
- if pad_decision.explicit_pad_hw is not None:
95
- pad_h, pad_w = pad_decision.explicit_pad_hw
96
- paddings = torch.tensor(
97
- [
98
- [0, 0],
99
- [pad_h, pad_h],
100
- [pad_w, pad_w],
101
- [0, 0],
102
- ],
103
- dtype=torch.int32,
104
- )
105
- pad_output_shape: List[int | torch.SymInt] = [
106
- input_shape[0],
107
- input_shape[1] + pad_h * 2,
108
- input_shape[2] + pad_w * 2,
109
- input_shape[3],
110
- ]
111
- pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
112
- pad_output_shape
113
- )
114
- # create padded output tensor
115
- input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
116
- pad_output = self.graph.add_tensor_from_scratch(
117
- prefix=f"{node.name}_input_pad_output",
118
- shape=pad_output_cshape,
119
- shape_signature=pad_output_cshape_signature,
120
- dtype=extract_circle_dtype(input_),
121
- qparam=input_qparam,
122
- source_node=node,
123
- )
124
- # CirclePad
125
- pad_operator = define_pad_node(
126
- self.graph, self._op_codes, [input_, paddings], [pad_output]
127
- )
128
- self.graph.add_operator(pad_operator)
129
- conv_input = pad_output
130
-
131
119
  if bias is None:
132
120
  # luci-interpreter can't run no bias conv. Let's add zero vector for bias.
133
121
  bias = [0.0] * weight_shape[0] # type: ignore[assignment]
134
122
 
135
- # First arguemnt is output shape of tconv.
136
- assert output_shape[0] == input_shape[0]
137
- assert output_shape[3] == weight_shape[0]
138
- tconv_output = circle_legalize_dtype_to(output_shape, dtype=torch.int32)
139
-
140
- tconv_output_tensor = self.graph.add_const_tensor(
141
- tconv_output, source_node=node
142
- )
123
+ # Compute pre-crop output shape if we need to apply an explicit crop.
124
+ if pad_decision.output_crop_hw is not None:
125
+ pad_h, pad_w = pad_decision.output_crop_hw
126
+ pre_h = int(output_shape[1]) + 2 * pad_h
127
+ pre_w = int(output_shape[2]) + 2 * pad_w
128
+ pre_out_shape = [output_shape[0], pre_h, pre_w, output_shape[3]] # NHWC
129
+ else:
130
+ pre_out_shape = list(output_shape)
131
+
132
+ tconv_output = circle_legalize_dtype_to(pre_out_shape, dtype=torch.int32)
133
+
134
+ pre_out_cshape, pre_out_csig = to_circle_shape(pre_out_shape)
135
+ tconv_tmp = node # type: ignore[assignment]
136
+ if pad_decision.output_crop_hw is not None:
137
+ tconv_tmp = self.graph.add_tensor_from_scratch( # type: ignore[assignment]
138
+ prefix=f"{node.name}_tconv_out_pre_crop",
139
+ shape=pre_out_cshape,
140
+ shape_signature=pre_out_csig,
141
+ dtype=extract_circle_dtype(node),
142
+ qparam=node.meta.get(QPARAM_KEY),
143
+ source_node=node,
144
+ )
143
145
 
144
- # TConv2D
145
146
  tconv2d_operator = self.define_transpose_conv_node(
146
- pad_decision.conv_padding_type, # 'SAME'(0) or 'VALID'(1)
147
+ pad_decision.conv_padding_type,
147
148
  stride,
148
- [tconv_output_tensor, weight, conv_input, bias],
149
- [node],
149
+ [tconv_output, weight, conv_input, bias],
150
+ [tconv_tmp],
150
151
  )
151
152
 
153
+ # If we need an output crop, insert a SLICE to produce the final tensor.
154
+ if pad_decision.output_crop_hw is not None:
155
+ self.graph.add_operator(tconv2d_operator)
156
+ pad_h, pad_w = pad_decision.output_crop_hw
157
+ begin = [0, pad_h, pad_w, 0]
158
+ size = [
159
+ int(output_shape[0]),
160
+ int(output_shape[1]),
161
+ int(output_shape[2]),
162
+ int(output_shape[3]),
163
+ ]
164
+
165
+ slice_op = self.define_slice_node(tconv_tmp, begin, size, node)
166
+ return slice_op
167
+
152
168
  return tconv2d_operator
tico/utils/convert.py CHANGED
@@ -25,10 +25,12 @@ from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
25
25
  from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
26
26
  from tico.passes.const_prop_pass import ConstPropPass
27
27
  from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
28
+ from tico.passes.convert_conv3d_to_conv2d import ConvertConv3dToConv2d
28
29
  from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
29
30
  from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
30
31
  from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
31
32
  from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
33
+ from tico.passes.convert_sym_size_to_circle_shape import ConvertSymSizeToCircleShape
32
34
  from tico.passes.convert_to_relu6 import ConvertToReLU6
33
35
  from tico.passes.decompose_addmm import DecomposeAddmm
34
36
  from tico.passes.decompose_batch_norm import DecomposeBatchNorm
@@ -47,6 +49,7 @@ from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
47
49
  from tico.passes.legalize_predefined_layout_operators import (
48
50
  LegalizePreDefinedLayoutOperators,
49
51
  )
52
+ from tico.passes.lower_copy import LowerCopy
50
53
  from tico.passes.lower_pow2_to_mul import LowerPow2ToMul
51
54
  from tico.passes.lower_to_resize_nearest_neighbor import LowerToResizeNearestNeighbor
52
55
  from tico.passes.lower_to_slice import passes as LowerToSlicePasses
@@ -97,12 +100,15 @@ def traced_run_decompositions(exported_program: ExportedProgram):
97
100
  torch.ops.aten.conv2d.padding,
98
101
  torch.ops.aten.conv1d.default,
99
102
  torch.ops.aten.conv1d.padding,
103
+ torch.ops.aten.conv3d.default,
104
+ torch.ops.aten.conv3d.padding,
100
105
  torch.ops.aten.conv_transpose2d.input,
101
106
  torch.ops.aten.instance_norm.default,
102
107
  torch.ops.aten._safe_softmax.default,
103
108
  torch.ops.aten.relu6.default, # Do not decompose to hardtanh
104
109
  torch.ops.aten.linear.default,
105
110
  torch.ops.aten.upsample_nearest2d.vec,
111
+ torch.ops.aten.rms_norm.default,
106
112
  )
107
113
  ep = ep.run_decompositions(_preserve_ops=_preserve_ops)
108
114
 
@@ -115,6 +121,8 @@ def traced_run_decompositions(exported_program: ExportedProgram):
115
121
  torch.ops.aten.conv2d.padding,
116
122
  torch.ops.aten.conv1d.default,
117
123
  torch.ops.aten.conv1d.padding,
124
+ torch.ops.aten.conv3d.default,
125
+ torch.ops.aten.conv3d.padding,
118
126
  torch.ops.aten.conv_transpose2d.input,
119
127
  torch.ops.aten.instance_norm.default,
120
128
  torch.ops.aten._safe_softmax.default,
@@ -138,6 +146,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
138
146
  or torch.__version__.startswith("2.8")
139
147
  or torch.__version__.startswith("2.9")
140
148
  or torch.__version__.startswith("2.10")
149
+ or torch.__version__.startswith("2.11")
141
150
  ):
142
151
  return run_decompositions(exported_program)
143
152
  else:
@@ -224,6 +233,8 @@ def convert_exported_module_to_circle(
224
233
  FillMetaVal(),
225
234
  ExtractDtypeKwargsPass(),
226
235
  RemoveNop(),
236
+ LowerCopy(),
237
+ ConvertSymSizeToCircleShape(),
227
238
  ConvertLayoutOpToReshape(),
228
239
  RestoreLinear(),
229
240
  ConvertToReLU6(),
@@ -258,6 +269,7 @@ def convert_exported_module_to_circle(
258
269
  LegalizePreDefinedLayoutOperators(),
259
270
  LowerPow2ToMul(),
260
271
  ConvertConv1dToConv2d(),
272
+ ConvertConv3dToConv2d(),
261
273
  *LowerToSlicePasses(),
262
274
  FuseLeadingUnsqueezeReshape(),
263
275
  CastClampMixedTypeArgs(),
@@ -316,7 +328,10 @@ def convert(
316
328
  mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
317
329
  )
318
330
 
319
- circle_binary = convert_exported_module_to_circle(exported_program, config=config)
331
+ with SuppressWarning(FutureWarning, ".*LeafSpec*"):
332
+ circle_binary = convert_exported_module_to_circle(
333
+ exported_program, config=config
334
+ )
320
335
 
321
336
  return CircleModel(circle_binary)
322
337
 
tico/utils/padding.py CHANGED
@@ -35,6 +35,7 @@ class ConvPaddingInfo(NamedTuple):
35
35
 
36
36
  conv_padding_type: ConvPadding
37
37
  explicit_pad_hw: Optional[Tuple[int, int]] # None -> no extra Pad() op needed
38
+ output_crop_hw: Optional[Tuple[int, int]]
38
39
 
39
40
 
40
41
  def identify_padding(
@@ -42,6 +43,7 @@ def identify_padding(
42
43
  input_shape: Sequence[int | torch.SymInt] | torch.Size,
43
44
  output_shape: Sequence[int | torch.SymInt] | torch.Size,
44
45
  stride: Sequence[int],
46
+ is_transpose: bool = False,
45
47
  ) -> ConvPaddingInfo:
46
48
  """
47
49
  Normalizes all PyTorch `padding` variants to a single decision.
@@ -61,9 +63,9 @@ def identify_padding(
61
63
  if isinstance(padding, str):
62
64
  pad = padding.lower()
63
65
  if pad == "valid":
64
- return ConvPaddingInfo(ConvPadding.VALID, None)
66
+ return ConvPaddingInfo(ConvPadding.VALID, None, None)
65
67
  if pad == "same":
66
- return ConvPaddingInfo(ConvPadding.SAME, None)
68
+ return ConvPaddingInfo(ConvPadding.SAME, None, None)
67
69
  raise InvalidArgumentError(f"Unknown padding string: {padding}")
68
70
 
69
71
  # ─── 2. List / tuple form ─────────────────────────────────────────────
@@ -73,15 +75,21 @@ def identify_padding(
73
75
  )
74
76
 
75
77
  pad_h, pad_w = padding
78
+
79
+ if is_transpose:
80
+ if pad_h == 0 and pad_w == 0:
81
+ return ConvPaddingInfo(ConvPadding.VALID, None, None)
82
+ return ConvPaddingInfo(ConvPadding.VALID, None, (pad_h, pad_w))
83
+
76
84
  # [0, 0] → VALID
77
85
  if pad_h == 0 and pad_w == 0:
78
- return ConvPaddingInfo(ConvPadding.VALID, None)
86
+ return ConvPaddingInfo(ConvPadding.VALID, None, None)
79
87
 
80
88
  # SAME heuristic: output H/W already match input when stride is 1
81
89
  hw_in = tuple(input_shape[1:3])
82
90
  hw_out = tuple(output_shape[1:3])
83
91
  if hw_in == hw_out and stride == [1, 1]:
84
- return ConvPaddingInfo(ConvPadding.SAME, None)
92
+ return ConvPaddingInfo(ConvPadding.SAME, None, None)
85
93
 
86
94
  # Anything else = explicit symmetric padding
87
- return ConvPaddingInfo(ConvPadding.VALID, (pad_h, pad_w))
95
+ return ConvPaddingInfo(ConvPadding.VALID, (pad_h, pad_w), None)
tico/utils/record_input.py CHANGED
@@ -49,13 +49,13 @@ class RecordingInput:
49
49
  def __init__(
50
50
  self,
51
51
  module: nn.Module,
52
- condition: Callable[[dict], bool] = lambda args_dict: True,
52
+ condition: Optional[Callable[[dict], bool]] = None,
53
53
  *,
54
54
  input_to_remove: Optional[List[str]] = [],
55
55
  ):
56
56
  self.module = module
57
57
  self.forward_org = module.forward
58
- self.condition = condition
58
+ self.condition = condition or (lambda args_dict: True)
59
59
  self.input_to_remove = input_to_remove
60
60
  self.sig = inspect.signature(self.forward_org)
61
61
 
tico/utils/register_custom_op.py CHANGED
@@ -727,6 +727,67 @@ def CircleRMSNorm():
727
727
  return hidden_states.new_empty(hidden_states.size())
728
728
 
729
729
 
730
+ def CircleAttention():
731
+ @custom_op("circle_custom::attention", mutates_args=())
732
+ def attention(
733
+ hidden_states: torch.Tensor,
734
+ wq: torch.Tensor,
735
+ wk: torch.Tensor,
736
+ wv: torch.Tensor,
737
+ wo: torch.Tensor,
738
+ position_cos: torch.Tensor,
739
+ position_sin: torch.Tensor,
740
+ attention_mask: torch.Tensor,
741
+ past_key: torch.Tensor,
742
+ past_value: torch.Tensor,
743
+ cache_position: torch.Tensor,
744
+ ) -> torch.Tensor:
745
+ return None
746
+
747
+ @register_fake("circle_custom::attention")
748
+ def _(
749
+ hidden_states: torch.Tensor,
750
+ wq: torch.Tensor,
751
+ wk: torch.Tensor,
752
+ wv: torch.Tensor,
753
+ wo: torch.Tensor,
754
+ position_cos: torch.Tensor,
755
+ position_sin: torch.Tensor,
756
+ attention_mask: torch.Tensor,
757
+ past_key: torch.Tensor,
758
+ past_value: torch.Tensor,
759
+ cache_position: torch.Tensor,
760
+ ) -> torch.Tensor:
761
+ return hidden_states
762
+
763
+
764
+ def CircleShape():
765
+ """
766
+ Custom operator to extract the shape of a tensor.
767
+ This is similar to TensorFlow's shape operator and is used to preserve
768
+ dynamic shape information in the Circle model.
769
+
770
+ Args:
771
+ input_: Input tensor
772
+
773
+ Returns:
774
+ A 1D tensor containing the shape of the input tensor
775
+ """
776
+
777
+ @custom_op("circle_custom::shape", mutates_args=())
778
+ def shape(input_: torch.Tensor) -> torch.Tensor:
779
+ # Return the shape of the input tensor as a 1D tensor
780
+ shape_val = list(input_.size())
781
+ return torch.tensor(shape_val, dtype=torch.int32)
782
+
783
+ @register_fake("circle_custom::shape")
784
+ def _(input_: torch.Tensor) -> torch.Tensor:
785
+ # Return a 1D tensor with symbolic shape
786
+ # The actual value will be determined at runtime
787
+ rank = len(input_.size())
788
+ return torch.empty([rank], dtype=torch.int32)
789
+
790
+
730
791
  # Add custom ops to the torch namespace
731
792
  def RegisterOps():
732
793
  CircleResizeNearestNeighbor()
@@ -740,3 +801,5 @@ def RegisterOps():
740
801
  CircleInstanceNorm()
741
802
  CircleQuantizeMX()
742
803
  CircleRMSNorm()
804
+ CircleAttention()
805
+ CircleShape()
tico/utils/validate_args_kwargs.py CHANGED
@@ -171,6 +171,26 @@ class CatArgs:
171
171
  dim: int = 0
172
172
 
173
173
 
174
+ @enforce_type
175
+ @dataclass
176
+ class CircleAttentionArgs:
177
+ """
178
+ For circle.BuiltinOperator.BuiltinOperator.ATTENTION
179
+ """
180
+
181
+ hidden_states: torch.fx.Node
182
+ wq: torch.fx.Node
183
+ wk: torch.fx.Node
184
+ wv: torch.fx.Node
185
+ wo: torch.fx.Node
186
+ position_cos: torch.fx.Node
187
+ position_sin: torch.fx.Node
188
+ attention_mask: torch.fx.Node
189
+ past_key: torch.fx.Node
190
+ past_value: torch.fx.Node
191
+ cache_position: torch.fx.Node
192
+
193
+
174
194
  @enforce_type
175
195
  @dataclass
176
196
  class CircleRMSNormArgs:
@@ -299,6 +319,27 @@ class Conv1DArgs:
299
319
  assert len(self.dilation) == 1, len(self.dilation)
300
320
 
301
321
 
322
+ @enforce_type
323
+ @dataclass
324
+ class Conv3DArgs:
325
+ """
326
+ conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor
327
+ conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor
328
+ """
329
+
330
+ input: torch.fx.Node
331
+ weight: torch.fx.Node
332
+ bias: Union[torch.fx.Node, None] = None
333
+ stride: List[int] = field(default_factory=lambda: [1, 1, 1])
334
+ padding: Union[List[int], str] = field(default_factory=lambda: [0, 0, 0])
335
+ dilation: List[int] = field(default_factory=lambda: [1, 1, 1])
336
+ groups: int = 1
337
+
338
+ def __post_init__(self):
339
+ assert len(self.stride) == 3, len(self.stride)
340
+ assert len(self.dilation) == 3, len(self.dilation)
341
+
342
+
302
343
  @enforce_type
303
344
  @dataclass
304
345
  class CopyArgs:
@@ -930,7 +971,7 @@ class RepeatArgs:
930
971
  """
931
972
 
932
973
  input: torch.fx.Node
933
- repeats: List[int]
974
+ repeats: List[Union[int, torch.SymInt, torch.fx.Node]]
934
975
 
935
976
 
936
977
  @enforce_type
@@ -938,10 +979,14 @@ class RepeatArgs:
938
979
  class ReshapeArgs:
939
980
  """
940
981
  reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
982
+
983
+ Note: After PrepareReshapeDynamicShape pass, shape can be either:
984
+ - A list of int/SymInt/Node (original or static)
985
+ - A single Node (dynamic shape tensor prepared by the pass)
941
986
  """
942
987
 
943
988
  input: torch.fx.Node
944
- shape: List[int]
989
+ shape: Union[List[Union[int, torch.SymInt, torch.fx.Node]], torch.fx.Node]
945
990
 
946
991
 
947
992
  @enforce_type
@@ -1077,7 +1122,7 @@ class SplitWithSizesArgs:
1077
1122
  """
1078
1123
 
1079
1124
  input: torch.fx.Node
1080
- split_sizes: List[int]
1125
+ split_sizes: List[Union[int, torch.SymInt, torch.fx.Node]]
1081
1126
  dim: int = 0
1082
1127
 
1083
1128
 
@@ -1218,7 +1263,7 @@ class ViewArgs:
1218
1263
  """
1219
1264
 
1220
1265
  input: torch.fx.Node
1221
- size: List[int]
1266
+ size: List[Union[int, torch.SymInt, torch.fx.Node]]
1222
1267
 
1223
1268
 
1224
1269
  @enforce_type