tico 0.1.0.dev250603__py3-none-any.whl → 0.1.0.dev250604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250603"
24
+ __version__ = "0.1.0.dev250604"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -316,12 +316,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
316
316
  ceil_mode = args.ceil_mode
317
317
  if ceil_mode:
318
318
  raise NotYetSupportedError("Only support non-ceil model.")
319
- count_include_pad = args.count_include_pad
320
- if not count_include_pad:
321
- # NOTE count_include_pad = False can be partially supported with SAME padding in circle.
322
- raise NotYetSupportedError(
323
- "For the case that the count_include_pad is False is not yet supported."
324
- )
325
319
  divisor_override = args.divisor_override
326
320
  if divisor_override is not None:
327
321
  raise NotYetSupportedError(
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import math
15
16
  from typing import Dict, List, TYPE_CHECKING
16
17
 
17
18
  if TYPE_CHECKING:
@@ -26,25 +27,101 @@ from tico.serialize.operators.hashable_opcode import OpCode
26
27
  from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
28
  from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
29
  from tico.utils.define import define_pad_node
30
+ from tico.utils.errors import NotYetSupportedError
29
31
  from tico.utils.validate_args_kwargs import AvgPool2dArgs
30
32
 
31
33
 
32
34
  @register_node_visitor
33
35
  class AvgPool2DVisitor(NodeVisitor):
36
+ """
37
+ This class defines how to serialize AvgPool2D operation into Circle IR.
38
+
39
+ Torch | Circle
40
+
41
+ count_include_pad: True/False | (count_include_pad): Always False
42
+ padding: number (could be valid, same, or etc) | padding: "valid"/"same"
43
+
44
+ * Circle's avgpool2d has no option for count_include_pad, so we always set it as False.
45
+ """
46
+
34
47
  target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.avgpool2d]
35
48
 
36
49
  def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
37
50
  super().__init__(op_codes, graph)
38
51
 
52
+ def has_padding(self, args: AvgPool2dArgs) -> bool:
53
+ padding = args.padding
54
+ if padding[0] == 0 and padding[1] == 0:
55
+ return False
56
+ else:
57
+ return True
58
+
59
+ def has_same_padding(self, args: AvgPool2dArgs) -> bool:
60
+ input_shape = list(extract_shape(args.input))
61
+ kernel_size = args.kernel_size
62
+ stride = args.stride
63
+ padding = args.padding
64
+ # TODO Update this function when supporting ceil_mode = True
65
+ assert args.ceil_mode is False
66
+ output_height = math.floor(
67
+ (input_shape[1] + padding[0] * 2 - kernel_size[0]) / stride[0] + 1
68
+ )
69
+ output_width = math.floor(
70
+ (input_shape[2] + padding[1] * 2 - kernel_size[1]) / stride[1] + 1
71
+ )
72
+
73
+ return input_shape[1] == output_height and input_shape[2] == output_width
74
+
75
+ def define_avgpool_node(self, inputs, outputs, padding, stride, kernel_size):
76
+ op_index = get_op_index(
77
+ circle.BuiltinOperator.BuiltinOperator.AVERAGE_POOL_2D,
78
+ self._op_codes,
79
+ )
80
+ operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
81
+
82
+ # Op-specific option
83
+ operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
84
+ option = circle.Pool2DOptions.Pool2DOptionsT()
85
+
86
+ assert padding in {"SAME": 0, "VALID": 1}
87
+
88
+ option.padding = {"SAME": 0, "VALID": 1}[padding]
89
+ option.strideH = stride[0]
90
+ option.strideW = stride[1]
91
+ option.filterHeight = kernel_size[0]
92
+ option.filterWidth = kernel_size[1]
93
+ option.fusedActivationFunction = (
94
+ circle.ActivationFunctionType.ActivationFunctionType.NONE
95
+ )
96
+
97
+ operator.builtinOptions = option
98
+ return operator
99
+
39
100
  def define_node(
40
101
  self,
41
102
  node: torch.fx.Node,
42
103
  ) -> circle.Operator.OperatorT:
104
+ """
105
+ PSEUDO CODE
106
+
107
+ if count_include_pad == True:
108
+ (Circle cannot represent count_include_pad=True in AvgPool2D. Therefore we manually add zero padding node.)
109
+ DEFINE zero padding node
110
+ DEFINE avgpool node with no padding (valid)
111
+ if count_include_pad == False:
112
+ (Lucky! Circle can represent count_include_pad=False)
113
+ DEFINE avgpool node with same/valid padding.
114
+
115
+ (However, it cannot represent all paddings. So, if the padding is not same or valid, we throw an error.)
116
+ if the paddding is neither same nor valid:
117
+ THROW an error.
118
+ """
43
119
  args = AvgPool2dArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
44
120
  input = args.input
45
121
  kernel_size = args.kernel_size
46
122
  stride = args.stride
47
123
  padding = args.padding
124
+ count_include_pad = args.count_include_pad
48
125
 
49
126
  avgpool_input: torch.fx.Node | circle.Tensor.TensorT = input
50
127
 
@@ -81,32 +158,33 @@ class AvgPool2DVisitor(NodeVisitor):
81
158
  self.graph.add_operator(pad_operator)
82
159
  return padded_input_tensor
83
160
 
84
- if padding is not None:
85
- avgpool_input = define_padding_node()
161
+ if count_include_pad is True:
162
+ # Add padding before avgpool2d
163
+ # Circle's avgpool2d does not support count_include_pad=True, so we need to add padding manually
164
+ if self.has_padding(args):
165
+ avgpool_input = define_padding_node()
86
166
 
87
- inputs = [avgpool_input]
88
- outputs = [node]
89
-
90
- op_index = get_op_index(
91
- circle.BuiltinOperator.BuiltinOperator.AVERAGE_POOL_2D,
92
- self._op_codes,
93
- )
94
- operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
95
-
96
- # Op-specific option
97
- operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
98
- option = circle.Pool2DOptions.Pool2DOptionsT()
99
-
100
- SAME, VALID = 0, 1
101
- option.padding = VALID
102
- option.strideH = stride[0]
103
- option.strideW = stride[1]
104
- option.filterHeight = kernel_size[0]
105
- option.filterWidth = kernel_size[1]
106
- option.fusedActivationFunction = (
107
- circle.ActivationFunctionType.ActivationFunctionType.NONE
108
- )
109
-
110
- operator.builtinOptions = option
111
-
112
- return operator
167
+ result = self.define_avgpool_node(
168
+ [avgpool_input], [node], "VALID", stride, kernel_size
169
+ )
170
+ elif count_include_pad is False:
171
+ if not self.has_padding(args): # valid padding
172
+ result = self.define_avgpool_node(
173
+ [avgpool_input], [node], "VALID", stride, kernel_size
174
+ )
175
+ elif self.has_same_padding(args):
176
+ result = self.define_avgpool_node(
177
+ [avgpool_input], [node], "SAME", stride, kernel_size
178
+ )
179
+ else:
180
+ # CASE: count_include_pad is False and not VALID/SAME padding
181
+ #
182
+ # Implement this when it's needed.
183
+ # If needed, may it help: the idea of ratio masking in https://github.com/Samsung/TICO/pull/119
184
+ raise NotYetSupportedError(
185
+ f"Padding({padding}) with count_include_pad({count_include_pad}) is not supported yet."
186
+ )
187
+ else:
188
+ raise RuntimeError("Cannot reach here")
189
+
190
+ return result
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250603
3
+ Version: 0.1.0.dev250604
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=b-iBFbe5L_XPwZaSnmweHVjabHVJPCZBqZOh5P1nUZU,1743
1
+ tico/__init__.py,sha256=X2eazANf-Y9gV1fcug-QytZyLbNFnJqqW658DamD_wI,1743
2
2
  tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -77,7 +77,7 @@ tico/passes/extract_dtype_kwargs.py,sha256=hfGJ_GfZULbBmLif2AJkhPHVifhucxBiLoQI8
77
77
  tico/passes/fill_meta_val.py,sha256=Xbam6Aq90ZfWItZw1dgLIwH_q8RCiU5JodKNqkj-ink,1797
78
78
  tico/passes/fuse_redundant_reshape_to_mean.py,sha256=SzHLLL2Yfj4k1c2L5i4PVI9EFUilHRfIS6S-hahKFCM,3702
79
79
  tico/passes/legalize_causal_mask_value.py,sha256=KZc_UPk7CGPXO35JOu6dVrOzRJx-ZpJoVvuzz-GvQek,4080
80
- tico/passes/legalize_predefined_layout_operators.py,sha256=OKVKxw039Bl9df7YCt17mWOHaGhQLWkJeSIPQpWdNBM,16224
80
+ tico/passes/legalize_predefined_layout_operators.py,sha256=N2TtJInjSTk-E5afnkDXXbo9v4zTM7yzsjna3VoihMw,15895
81
81
  tico/passes/lower_pow2_to_mul.py,sha256=imx9CoKG4bLyNqv4F-Z203s_P0-0SdRH-y4_Q0PTZVo,2304
82
82
  tico/passes/lower_to_resize_nearest_neighbor.py,sha256=4bIxPSyNEzznTw8f8D9hMNXbZ0KmMPPPfRvITXonEz0,8881
83
83
  tico/passes/lower_to_slice.py,sha256=6xK7A2hdqdBdN2XRyt2rFGnqJGcaXxmgndis4kn2q2w,7112
@@ -107,7 +107,7 @@ tico/serialize/operators/op_alias_copy.py,sha256=Xu9OiILbGf8oddh8yTqovvLfgVs8XYV
107
107
  tico/serialize/operators/op_any.py,sha256=WMsHLq7WIcl6rD2G3QqpWRSCR-a6UYX6y5AjB6BDS3U,5049
108
108
  tico/serialize/operators/op_arange_start_step.py,sha256=0T5lWwh3TfsFStmVv0v5qG03KENRDBmMix08RXQ4D-U,2132
109
109
  tico/serialize/operators/op_argmax.py,sha256=ARyGHlmWVmzwCct93V5x1-VyKqhxMOvV8GuM8yQWXdo,2290
110
- tico/serialize/operators/op_avg_pool2d.py,sha256=kLtbB1VmGC8H51rF4OmQIZxpUkpK1ZlqYXNkJdvABYc,4068
110
+ tico/serialize/operators/op_avg_pool2d.py,sha256=ABxhfowDz7SXlWnW2iQuSA5X52xm0PGLs-N1l9vGXbo,7488
111
111
  tico/serialize/operators/op_bmm.py,sha256=AELjHC9ISFPIzEEl5Kr1s4GSNLZElwZmVZJWkEyCEoA,2189
112
112
  tico/serialize/operators/op_cat.py,sha256=XDYOh0XAyrM0TlxVm6Sa0OFFGrKk7aSDcGXC-hYX4gs,2204
113
113
  tico/serialize/operators/op_clamp.py,sha256=V3rncHvUAuJ2nXOyywTnOGCvNBeCQGqQIW1_zxKlSsA,4231
@@ -194,9 +194,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
194
194
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
195
195
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
196
196
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
197
- tico-0.1.0.dev250603.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
198
- tico-0.1.0.dev250603.dist-info/METADATA,sha256=CcrcTSCD543XA0KQHdiYENeTXpcbw4p7OwkJtHOXtiY,8633
199
- tico-0.1.0.dev250603.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
200
- tico-0.1.0.dev250603.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
201
- tico-0.1.0.dev250603.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
202
- tico-0.1.0.dev250603.dist-info/RECORD,,
197
+ tico-0.1.0.dev250604.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
198
+ tico-0.1.0.dev250604.dist-info/METADATA,sha256=gN7iYzRhOodrWqExhbNBTMQ5S_6iOpb-vte13nECcak,8633
199
+ tico-0.1.0.dev250604.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
200
+ tico-0.1.0.dev250604.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
201
+ tico-0.1.0.dev250604.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
202
+ tico-0.1.0.dev250604.dist-info/RECORD,,