tico 0.1.0.dev250520__py3-none-any.whl → 0.1.0.dev250521__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -22,7 +22,7 @@ from tico.config import CompileConfigV1, get_default_config
22
22
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
23
23
 
24
24
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
25
- __version__ = "0.1.0.dev250520"
25
+ __version__ = "0.1.0.dev250521"
26
26
 
27
27
 
28
28
  if Version(torch.__version__) < Version("2.5.0"):
@@ -88,6 +88,25 @@ class DecomposeGroupNorm(PassBase):
88
88
  def __init__(self):
89
89
  super().__init__()
90
90
 
91
+ def _insert_norm(self, graph, tensor, eps):
92
+ """
93
+ Insert (tensor - mean) / sqrt(var + eps)) into the graph
94
+ and return the normalized tensor node.
95
+ """
96
+ mean = graph.call_function(
97
+ torch.ops.aten.mean.dim, (tensor, [-1]), {"keepdim": True}
98
+ )
99
+ deviation = graph.call_function(torch.ops.aten.sub.Tensor, (tensor, mean))
100
+ squared = graph.call_function(torch.ops.aten.pow.Tensor_Scalar, (deviation, 2))
101
+ var = graph.call_function(
102
+ torch.ops.aten.mean.dim, (squared, [-1]), {"keepdim": True}
103
+ )
104
+ inverse_std = graph.call_function(
105
+ torch.ops.aten.rsqrt.default,
106
+ (graph.call_function(torch.ops.aten.add.Tensor, (var, eps)),),
107
+ )
108
+ return graph.call_function(torch.ops.aten.mul.Tensor, (deviation, inverse_std))
109
+
91
110
  def call(self, exported_program: ExportedProgram) -> PassResult:
92
111
  logger = logging.getLogger(__name__)
93
112
 
@@ -155,52 +174,20 @@ class DecomposeGroupNorm(PassBase):
155
174
  pack_shape = [layer_size, norm_size]
156
175
 
157
176
  with gm.graph.inserting_before(node):
158
- layer = graph.call_function(
159
- # Sometimes, `x` has a stride for NHWC, which can't be reshaped with `aten.view.default`.
160
- # TODO Find out how to process such case properly.
161
- torch.ops.aten.reshape.default,
162
- (x, pack_shape),
163
- )
164
- layer_mean = graph.call_function(
165
- torch.ops.aten.mean.dim,
166
- (layer, [-1]),
167
- )
168
- layer_mean_reshape = graph.call_function(
169
- torch.ops.aten.view.default,
170
- (layer_mean, [layer_size, 1]),
171
- )
172
- layer_deviation = graph.call_function(
173
- torch.ops.aten.sub.Tensor,
174
- (layer, layer_mean_reshape),
175
- )
176
- layer_sqr_diff = graph.call_function(
177
- torch.ops.aten.pow.Tensor_Scalar,
178
- (layer_deviation, 2),
179
- )
180
- var = graph.call_function(
181
- torch.ops.aten.mean.dim,
182
- (layer_sqr_diff, [-1]),
183
- )
184
- var_eps = graph.call_function(
185
- torch.ops.aten.add.Tensor,
186
- (var, eps),
187
- )
188
- rstd = graph.call_function(
189
- torch.ops.aten.rsqrt.default,
190
- (var_eps,),
191
- )
192
- rstd_reshape = graph.call_function(
193
- torch.ops.aten.view.default,
194
- (rstd, [layer_size, 1]),
195
- )
196
- layer_norm = graph.call_function(
197
- torch.ops.aten.mul.Tensor,
198
- (layer_deviation, rstd_reshape),
199
- )
200
- layer_norm = graph.call_function(
201
- torch.ops.aten.view.default,
202
- (layer_norm, x_shape),
203
- )
177
+ # Branch only on whether a reshape is needed; the normalization is shared.
178
+ if norm_size != x_shape[-1]:
179
+ # Pack groups so that the last dimension equals norm_size.
180
+ packed = graph.call_function(
181
+ torch.ops.aten.reshape.default, (x, pack_shape)
182
+ )
183
+ normed = self._insert_norm(graph, packed, eps)
184
+ # Restore the original shape after normalization.
185
+ layer_norm = graph.call_function(
186
+ torch.ops.aten.reshape.default, (normed, x_shape)
187
+ )
188
+ else:
189
+ # The input already has norm_size in the last dimension.
190
+ layer_norm = self._insert_norm(graph, x, eps)
204
191
 
205
192
  # weight
206
193
  if weight:
@@ -45,7 +45,7 @@ class MeanVisitor(NodeVisitor):
45
45
  args = MeanDimArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
46
46
  input = args.input
47
47
  dim = args.dim
48
- keep_dims = args.keep_dims
48
+ keep_dims = args.keepdim
49
49
 
50
50
  dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
51
51
  inputs = [input, dim_i32]
@@ -599,7 +599,7 @@ class MeanDimArgs:
599
599
 
600
600
  input: torch.fx.Node
601
601
  dim: List[int]
602
- keep_dims: bool = False
602
+ keepdim: bool = False
603
603
  dtype: Optional[torch.dtype] = None
604
604
 
605
605
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250520
3
+ Version: 0.1.0.dev250521
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=80nv6fiJM89LQkz5EODIec-qHWe6hHtEOZ5u7L1wYnY,1181
1
+ tico/__init__.py,sha256=zsOHhqeEvcmCBCvdS74zuPFmwmyF7Ls-Y-7EGShtpIg,1181
2
2
  tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -70,7 +70,7 @@ tico/passes/decompose_addmm.py,sha256=_yNX7wx1Y9HJI5ksUJI-UQLHpoNawbUbF8kcm2zGHw
70
70
  tico/passes/decompose_batch_norm.py,sha256=d1V9UOkm_5BV0NGLyuQfz4I9NpO7I3ZrRugt7EXM-XM,7016
71
71
  tico/passes/decompose_fake_quantize.py,sha256=7ZJyTIDj2iKgWa5q8mBSq6k0GX0vs_XyQdsIiWFJoTU,5175
72
72
  tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=kOQaODKl_GCE19h-UZGmxnTcHtlvphI63dVAmMQL_Bk,13823
73
- tico/passes/decompose_group_norm.py,sha256=UVEzOBi4aUr3IlP6iWlYlo4u-8cmJ4JvKl8r2-qGBj4,9670
73
+ tico/passes/decompose_group_norm.py,sha256=xn1xnT-2e6BvelRAzX8O7wg9kBWURmPldkRvpfYFXHQ,9407
74
74
  tico/passes/decompose_grouped_conv2d.py,sha256=KJhH6PX7l9k9T8KBV8JDAvaSfJuUnRo_jtvGF2aM-LA,8277
75
75
  tico/passes/decompose_slice_scatter.py,sha256=ko9p8v-zY5rOx4aSpWomwSdSWb1lIF32gnU7ik5xgII,5604
76
76
  tico/passes/extract_dtype_kwargs.py,sha256=hfGJ_GfZULbBmLif2AJkhPHVifhucxBiLoQI862Yejk,4303
@@ -140,7 +140,7 @@ tico/serialize/operators/op_logical_not.py,sha256=ugrVcRqR3IvUUaiRVW5cArCYJbzmkc
140
140
  tico/serialize/operators/op_lt.py,sha256=_vA7dWpV9wVBxB7JL9pLQT9BsV91NGQBq_0auAtHK5Y,2080
141
141
  tico/serialize/operators/op_max_pool2d_with_indices.py,sha256=ilQdirgSOjJR6dRIgAEF-sFmjPLQB3O2F_Fq5mbpYNA,5203
142
142
  tico/serialize/operators/op_maximum.py,sha256=JjBr6gWEnuakLuk1_feotTHfIIm3s5YqWmqhUMpSPI0,1873
143
- tico/serialize/operators/op_mean.py,sha256=e7uRPXYHrq3lH7vANcjCHXTRdMWZgiGFBmFdGaOQmts,2269
143
+ tico/serialize/operators/op_mean.py,sha256=rVQZOxCJkHFY4kQBAS1HVK0HkcqxgkSy6zvEDLX_WYQ,2267
144
144
  tico/serialize/operators/op_minimum.py,sha256=fASjQVcTPCin02umQwFPdq2ss-Ve7S5A33J3QmmQ_wQ,1873
145
145
  tico/serialize/operators/op_mm.py,sha256=fHggR9dmlwXw0DAyn__2JbG7e0q1Jhfmi5-2jDlpRDk,6730
146
146
  tico/serialize/operators/op_mul.py,sha256=42Guc0MWBGBCZoj9-4LcLtTMtUPwsmDSVmvkR8tqLhM,3165
@@ -187,14 +187,14 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
187
187
  tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
188
188
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
189
189
  tico/utils/utils.py,sha256=pybDU1LoNhjEplANig11lboX9yzYRkvFCSmyYth_2Do,10359
190
- tico/utils/validate_args_kwargs.py,sha256=krT68b5CfBI9rxBIOsgYSy0LfEJqLfKfRikkp8ep9oQ,24726
190
+ tico/utils/validate_args_kwargs.py,sha256=z7JySDhyviI7G6Mo54AFdTiJIOE6Q6UrYnOFJj-5M24,24724
191
191
  tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
192
192
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
193
193
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
194
194
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
195
- tico-0.1.0.dev250520.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
196
- tico-0.1.0.dev250520.dist-info/METADATA,sha256=E3Tizp7XTBYhstti5Qh5Q5HYd-7XiVA6zfwUD8T0DII,8633
197
- tico-0.1.0.dev250520.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
198
- tico-0.1.0.dev250520.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
199
- tico-0.1.0.dev250520.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
200
- tico-0.1.0.dev250520.dist-info/RECORD,,
195
+ tico-0.1.0.dev250521.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
196
+ tico-0.1.0.dev250521.dist-info/METADATA,sha256=3DnR9WqN27XltH1Abj77Oc2hEgKbT99QZ0w0oVdvAuc,8633
197
+ tico-0.1.0.dev250521.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
198
+ tico-0.1.0.dev250521.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
199
+ tico-0.1.0.dev250521.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
200
+ tico-0.1.0.dev250521.dist-info/RECORD,,