tico 0.1.0.dev250520__py3-none-any.whl → 0.1.0.dev250521__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/passes/decompose_group_norm.py +33 -46
- tico/serialize/operators/op_mean.py +1 -1
- tico/utils/validate_args_kwargs.py +1 -1
- {tico-0.1.0.dev250520.dist-info → tico-0.1.0.dev250521.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250520.dist-info → tico-0.1.0.dev250521.dist-info}/RECORD +10 -10
- {tico-0.1.0.dev250520.dist-info → tico-0.1.0.dev250521.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250520.dist-info → tico-0.1.0.dev250521.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250520.dist-info → tico-0.1.0.dev250521.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250520.dist-info → tico-0.1.0.dev250521.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -22,7 +22,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
22
22
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
23
23
|
|
24
24
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
25
|
-
__version__ = "0.1.0.
|
25
|
+
__version__ = "0.1.0.dev250521"
|
26
26
|
|
27
27
|
|
28
28
|
if Version(torch.__version__) < Version("2.5.0"):
|
@@ -88,6 +88,25 @@ class DecomposeGroupNorm(PassBase):
|
|
88
88
|
def __init__(self):
|
89
89
|
super().__init__()
|
90
90
|
|
91
|
+
def _insert_norm(self, graph, tensor, eps):
|
92
|
+
"""
|
93
|
+
Insert (tensor - mean) / sqrt(var + eps)) into the graph
|
94
|
+
and return the normalized tensor node.
|
95
|
+
"""
|
96
|
+
mean = graph.call_function(
|
97
|
+
torch.ops.aten.mean.dim, (tensor, [-1]), {"keepdim": True}
|
98
|
+
)
|
99
|
+
deviation = graph.call_function(torch.ops.aten.sub.Tensor, (tensor, mean))
|
100
|
+
squared = graph.call_function(torch.ops.aten.pow.Tensor_Scalar, (deviation, 2))
|
101
|
+
var = graph.call_function(
|
102
|
+
torch.ops.aten.mean.dim, (squared, [-1]), {"keepdim": True}
|
103
|
+
)
|
104
|
+
inverse_std = graph.call_function(
|
105
|
+
torch.ops.aten.rsqrt.default,
|
106
|
+
(graph.call_function(torch.ops.aten.add.Tensor, (var, eps)),),
|
107
|
+
)
|
108
|
+
return graph.call_function(torch.ops.aten.mul.Tensor, (deviation, inverse_std))
|
109
|
+
|
91
110
|
def call(self, exported_program: ExportedProgram) -> PassResult:
|
92
111
|
logger = logging.getLogger(__name__)
|
93
112
|
|
@@ -155,52 +174,20 @@ class DecomposeGroupNorm(PassBase):
|
|
155
174
|
pack_shape = [layer_size, norm_size]
|
156
175
|
|
157
176
|
with gm.graph.inserting_before(node):
|
158
|
-
|
159
|
-
|
160
|
-
#
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
(
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
layer_deviation = graph.call_function(
|
173
|
-
torch.ops.aten.sub.Tensor,
|
174
|
-
(layer, layer_mean_reshape),
|
175
|
-
)
|
176
|
-
layer_sqr_diff = graph.call_function(
|
177
|
-
torch.ops.aten.pow.Tensor_Scalar,
|
178
|
-
(layer_deviation, 2),
|
179
|
-
)
|
180
|
-
var = graph.call_function(
|
181
|
-
torch.ops.aten.mean.dim,
|
182
|
-
(layer_sqr_diff, [-1]),
|
183
|
-
)
|
184
|
-
var_eps = graph.call_function(
|
185
|
-
torch.ops.aten.add.Tensor,
|
186
|
-
(var, eps),
|
187
|
-
)
|
188
|
-
rstd = graph.call_function(
|
189
|
-
torch.ops.aten.rsqrt.default,
|
190
|
-
(var_eps,),
|
191
|
-
)
|
192
|
-
rstd_reshape = graph.call_function(
|
193
|
-
torch.ops.aten.view.default,
|
194
|
-
(rstd, [layer_size, 1]),
|
195
|
-
)
|
196
|
-
layer_norm = graph.call_function(
|
197
|
-
torch.ops.aten.mul.Tensor,
|
198
|
-
(layer_deviation, rstd_reshape),
|
199
|
-
)
|
200
|
-
layer_norm = graph.call_function(
|
201
|
-
torch.ops.aten.view.default,
|
202
|
-
(layer_norm, x_shape),
|
203
|
-
)
|
177
|
+
# Branch only on whether a reshape is needed; the normalization is shared.
|
178
|
+
if norm_size != x_shape[-1]:
|
179
|
+
# Pack groups so that the last dimension equals norm_size.
|
180
|
+
packed = graph.call_function(
|
181
|
+
torch.ops.aten.reshape.default, (x, pack_shape)
|
182
|
+
)
|
183
|
+
normed = self._insert_norm(graph, packed, eps)
|
184
|
+
# Restore the original shape after normalization.
|
185
|
+
layer_norm = graph.call_function(
|
186
|
+
torch.ops.aten.reshape.default, (normed, x_shape)
|
187
|
+
)
|
188
|
+
else:
|
189
|
+
# The input already has norm_size in the last dimension.
|
190
|
+
layer_norm = self._insert_norm(graph, x, eps)
|
204
191
|
|
205
192
|
# weight
|
206
193
|
if weight:
|
@@ -45,7 +45,7 @@ class MeanVisitor(NodeVisitor):
|
|
45
45
|
args = MeanDimArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
46
46
|
input = args.input
|
47
47
|
dim = args.dim
|
48
|
-
keep_dims = args.
|
48
|
+
keep_dims = args.keepdim
|
49
49
|
|
50
50
|
dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
|
51
51
|
inputs = [input, dim_i32]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=zsOHhqeEvcmCBCvdS74zuPFmwmyF7Ls-Y-7EGShtpIg,1181
|
2
2
|
tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -70,7 +70,7 @@ tico/passes/decompose_addmm.py,sha256=_yNX7wx1Y9HJI5ksUJI-UQLHpoNawbUbF8kcm2zGHw
|
|
70
70
|
tico/passes/decompose_batch_norm.py,sha256=d1V9UOkm_5BV0NGLyuQfz4I9NpO7I3ZrRugt7EXM-XM,7016
|
71
71
|
tico/passes/decompose_fake_quantize.py,sha256=7ZJyTIDj2iKgWa5q8mBSq6k0GX0vs_XyQdsIiWFJoTU,5175
|
72
72
|
tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=kOQaODKl_GCE19h-UZGmxnTcHtlvphI63dVAmMQL_Bk,13823
|
73
|
-
tico/passes/decompose_group_norm.py,sha256=
|
73
|
+
tico/passes/decompose_group_norm.py,sha256=xn1xnT-2e6BvelRAzX8O7wg9kBWURmPldkRvpfYFXHQ,9407
|
74
74
|
tico/passes/decompose_grouped_conv2d.py,sha256=KJhH6PX7l9k9T8KBV8JDAvaSfJuUnRo_jtvGF2aM-LA,8277
|
75
75
|
tico/passes/decompose_slice_scatter.py,sha256=ko9p8v-zY5rOx4aSpWomwSdSWb1lIF32gnU7ik5xgII,5604
|
76
76
|
tico/passes/extract_dtype_kwargs.py,sha256=hfGJ_GfZULbBmLif2AJkhPHVifhucxBiLoQI862Yejk,4303
|
@@ -140,7 +140,7 @@ tico/serialize/operators/op_logical_not.py,sha256=ugrVcRqR3IvUUaiRVW5cArCYJbzmkc
|
|
140
140
|
tico/serialize/operators/op_lt.py,sha256=_vA7dWpV9wVBxB7JL9pLQT9BsV91NGQBq_0auAtHK5Y,2080
|
141
141
|
tico/serialize/operators/op_max_pool2d_with_indices.py,sha256=ilQdirgSOjJR6dRIgAEF-sFmjPLQB3O2F_Fq5mbpYNA,5203
|
142
142
|
tico/serialize/operators/op_maximum.py,sha256=JjBr6gWEnuakLuk1_feotTHfIIm3s5YqWmqhUMpSPI0,1873
|
143
|
-
tico/serialize/operators/op_mean.py,sha256=
|
143
|
+
tico/serialize/operators/op_mean.py,sha256=rVQZOxCJkHFY4kQBAS1HVK0HkcqxgkSy6zvEDLX_WYQ,2267
|
144
144
|
tico/serialize/operators/op_minimum.py,sha256=fASjQVcTPCin02umQwFPdq2ss-Ve7S5A33J3QmmQ_wQ,1873
|
145
145
|
tico/serialize/operators/op_mm.py,sha256=fHggR9dmlwXw0DAyn__2JbG7e0q1Jhfmi5-2jDlpRDk,6730
|
146
146
|
tico/serialize/operators/op_mul.py,sha256=42Guc0MWBGBCZoj9-4LcLtTMtUPwsmDSVmvkR8tqLhM,3165
|
@@ -187,14 +187,14 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
|
187
187
|
tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
|
188
188
|
tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
|
189
189
|
tico/utils/utils.py,sha256=pybDU1LoNhjEplANig11lboX9yzYRkvFCSmyYth_2Do,10359
|
190
|
-
tico/utils/validate_args_kwargs.py,sha256=
|
190
|
+
tico/utils/validate_args_kwargs.py,sha256=z7JySDhyviI7G6Mo54AFdTiJIOE6Q6UrYnOFJj-5M24,24724
|
191
191
|
tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
192
192
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
193
193
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
194
194
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
195
|
-
tico-0.1.0.
|
196
|
-
tico-0.1.0.
|
197
|
-
tico-0.1.0.
|
198
|
-
tico-0.1.0.
|
199
|
-
tico-0.1.0.
|
200
|
-
tico-0.1.0.
|
195
|
+
tico-0.1.0.dev250521.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
196
|
+
tico-0.1.0.dev250521.dist-info/METADATA,sha256=3DnR9WqN27XltH1Abj77Oc2hEgKbT99QZ0w0oVdvAuc,8633
|
197
|
+
tico-0.1.0.dev250521.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
198
|
+
tico-0.1.0.dev250521.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
199
|
+
tico-0.1.0.dev250521.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
200
|
+
tico-0.1.0.dev250521.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|