tico 0.1.0.dev250519__py3-none-any.whl → 0.1.0.dev250521__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +10 -4
- tico/passes/decompose_fake_quantize_tensor_qparams.py +14 -0
- tico/passes/decompose_group_norm.py +33 -46
- tico/serialize/operators/op_mean.py +1 -1
- tico/utils/validate_args_kwargs.py +1 -1
- {tico-0.1.0.dev250519.dist-info → tico-0.1.0.dev250521.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250519.dist-info → tico-0.1.0.dev250521.dist-info}/RECORD +12 -12
- {tico-0.1.0.dev250519.dist-info → tico-0.1.0.dev250521.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250519.dist-info → tico-0.1.0.dev250521.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250519.dist-info → tico-0.1.0.dev250521.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250519.dist-info → tico-0.1.0.dev250521.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -22,7 +22,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
22
22
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
23
23
|
|
24
24
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
25
|
-
__version__ = "0.1.0.
|
25
|
+
__version__ = "0.1.0.dev250521"
|
26
26
|
|
27
27
|
|
28
28
|
if Version(torch.__version__) < Version("2.5.0"):
|
@@ -17,6 +17,12 @@ from typing import List, Optional, TYPE_CHECKING, Union
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
import torch.fx
|
19
19
|
import torch
|
20
|
+
from torch._export.utils import (
|
21
|
+
get_buffer,
|
22
|
+
get_lifted_tensor_constant,
|
23
|
+
is_buffer,
|
24
|
+
is_lifted_tensor_constant,
|
25
|
+
)
|
20
26
|
from torch._subclasses.fake_tensor import FakeTensor
|
21
27
|
from torch.export import ExportedProgram
|
22
28
|
|
@@ -34,10 +40,10 @@ def get_constant(exported_program: ExportedProgram, node: torch.fx.Node):
|
|
34
40
|
assert isinstance(node, torch.fx.Node)
|
35
41
|
if node.name in exported_program.constants:
|
36
42
|
return exported_program.constants[node.name]
|
37
|
-
elif node
|
38
|
-
|
39
|
-
|
40
|
-
return
|
43
|
+
elif is_buffer(exported_program, node):
|
44
|
+
return get_buffer(exported_program, node)
|
45
|
+
elif is_lifted_tensor_constant(exported_program, node):
|
46
|
+
return get_lifted_tensor_constant(exported_program, node)
|
41
47
|
else:
|
42
48
|
raise RuntimeError("NYI constant")
|
43
49
|
|
@@ -18,6 +18,12 @@ if TYPE_CHECKING:
|
|
18
18
|
import torch._ops
|
19
19
|
import torch.fx
|
20
20
|
import torch
|
21
|
+
from torch._export.utils import (
|
22
|
+
get_buffer,
|
23
|
+
get_lifted_tensor_constant,
|
24
|
+
is_buffer,
|
25
|
+
is_lifted_tensor_constant,
|
26
|
+
)
|
21
27
|
|
22
28
|
# To import torch.ops.quantized_decomposed related operator
|
23
29
|
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
|
@@ -55,6 +61,14 @@ def get_constant_from_tensor(
|
|
55
61
|
"""
|
56
62
|
if isinstance(node, float):
|
57
63
|
return node
|
64
|
+
if is_buffer(ep, node):
|
65
|
+
buf = get_buffer(ep, node)
|
66
|
+
assert isinstance(buf, torch.Tensor)
|
67
|
+
return buf.item()
|
68
|
+
elif is_lifted_tensor_constant(ep, node):
|
69
|
+
lifted = get_lifted_tensor_constant(ep, node)
|
70
|
+
assert isinstance(lifted, torch.Tensor)
|
71
|
+
return lifted.item()
|
58
72
|
assert isinstance(node.target, torch._ops.OpOverload)
|
59
73
|
if node.target.__name__ == "mul.Tensor":
|
60
74
|
assert len(node.args) == 2
|
@@ -88,6 +88,25 @@ class DecomposeGroupNorm(PassBase):
|
|
88
88
|
def __init__(self):
|
89
89
|
super().__init__()
|
90
90
|
|
91
|
+
def _insert_norm(self, graph, tensor, eps):
|
92
|
+
"""
|
93
|
+
Insert (tensor - mean) / sqrt(var + eps)) into the graph
|
94
|
+
and return the normalized tensor node.
|
95
|
+
"""
|
96
|
+
mean = graph.call_function(
|
97
|
+
torch.ops.aten.mean.dim, (tensor, [-1]), {"keepdim": True}
|
98
|
+
)
|
99
|
+
deviation = graph.call_function(torch.ops.aten.sub.Tensor, (tensor, mean))
|
100
|
+
squared = graph.call_function(torch.ops.aten.pow.Tensor_Scalar, (deviation, 2))
|
101
|
+
var = graph.call_function(
|
102
|
+
torch.ops.aten.mean.dim, (squared, [-1]), {"keepdim": True}
|
103
|
+
)
|
104
|
+
inverse_std = graph.call_function(
|
105
|
+
torch.ops.aten.rsqrt.default,
|
106
|
+
(graph.call_function(torch.ops.aten.add.Tensor, (var, eps)),),
|
107
|
+
)
|
108
|
+
return graph.call_function(torch.ops.aten.mul.Tensor, (deviation, inverse_std))
|
109
|
+
|
91
110
|
def call(self, exported_program: ExportedProgram) -> PassResult:
|
92
111
|
logger = logging.getLogger(__name__)
|
93
112
|
|
@@ -155,52 +174,20 @@ class DecomposeGroupNorm(PassBase):
|
|
155
174
|
pack_shape = [layer_size, norm_size]
|
156
175
|
|
157
176
|
with gm.graph.inserting_before(node):
|
158
|
-
|
159
|
-
|
160
|
-
#
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
(
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
layer_deviation = graph.call_function(
|
173
|
-
torch.ops.aten.sub.Tensor,
|
174
|
-
(layer, layer_mean_reshape),
|
175
|
-
)
|
176
|
-
layer_sqr_diff = graph.call_function(
|
177
|
-
torch.ops.aten.pow.Tensor_Scalar,
|
178
|
-
(layer_deviation, 2),
|
179
|
-
)
|
180
|
-
var = graph.call_function(
|
181
|
-
torch.ops.aten.mean.dim,
|
182
|
-
(layer_sqr_diff, [-1]),
|
183
|
-
)
|
184
|
-
var_eps = graph.call_function(
|
185
|
-
torch.ops.aten.add.Tensor,
|
186
|
-
(var, eps),
|
187
|
-
)
|
188
|
-
rstd = graph.call_function(
|
189
|
-
torch.ops.aten.rsqrt.default,
|
190
|
-
(var_eps,),
|
191
|
-
)
|
192
|
-
rstd_reshape = graph.call_function(
|
193
|
-
torch.ops.aten.view.default,
|
194
|
-
(rstd, [layer_size, 1]),
|
195
|
-
)
|
196
|
-
layer_norm = graph.call_function(
|
197
|
-
torch.ops.aten.mul.Tensor,
|
198
|
-
(layer_deviation, rstd_reshape),
|
199
|
-
)
|
200
|
-
layer_norm = graph.call_function(
|
201
|
-
torch.ops.aten.view.default,
|
202
|
-
(layer_norm, x_shape),
|
203
|
-
)
|
177
|
+
# Branch only on whether a reshape is needed; the normalization is shared.
|
178
|
+
if norm_size != x_shape[-1]:
|
179
|
+
# Pack groups so that the last dimension equals norm_size.
|
180
|
+
packed = graph.call_function(
|
181
|
+
torch.ops.aten.reshape.default, (x, pack_shape)
|
182
|
+
)
|
183
|
+
normed = self._insert_norm(graph, packed, eps)
|
184
|
+
# Restore the original shape after normalization.
|
185
|
+
layer_norm = graph.call_function(
|
186
|
+
torch.ops.aten.reshape.default, (normed, x_shape)
|
187
|
+
)
|
188
|
+
else:
|
189
|
+
# The input already has norm_size in the last dimension.
|
190
|
+
layer_norm = self._insert_norm(graph, x, eps)
|
204
191
|
|
205
192
|
# weight
|
206
193
|
if weight:
|
@@ -45,7 +45,7 @@ class MeanVisitor(NodeVisitor):
|
|
45
45
|
args = MeanDimArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
46
46
|
input = args.input
|
47
47
|
dim = args.dim
|
48
|
-
keep_dims = args.
|
48
|
+
keep_dims = args.keepdim
|
49
49
|
|
50
50
|
dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
|
51
51
|
inputs = [input, dim_i32]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=zsOHhqeEvcmCBCvdS74zuPFmwmyF7Ls-Y-7EGShtpIg,1181
|
2
2
|
tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -54,7 +54,7 @@ tico/experimental/quantization/passes/fold_quant_ops.py,sha256=Jq5wmQDhdjsXxae2p
|
|
54
54
|
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=i4rkM1vlN85fXA9oOrU25o8KWAaqA65NKngTX6MgctQ,12960
|
55
55
|
tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
|
56
56
|
tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
|
57
|
-
tico/experimental/quantization/passes/remove_weight_dequant_op.py,sha256=
|
57
|
+
tico/experimental/quantization/passes/remove_weight_dequant_op.py,sha256=lNemHkr_IMg6kTIQjk4xLgW4DkDNBr0wTW3miNqmvkc,6450
|
58
58
|
tico/interpreter/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
59
59
|
tico/interpreter/infer.py,sha256=vJ3b69ce9HrxNT0gFwbEhHpAyvVyuiunTgAeiqn5t64,4350
|
60
60
|
tico/interpreter/interpreter.py,sha256=tGbluCbrehTCqBu8mtGDNzby_ieJ2ry8_RH_eC0CQxk,3828
|
@@ -69,8 +69,8 @@ tico/passes/convert_to_relu6.py,sha256=3sfKfggvjbl9N73pLOwgUTNyoecODsy367nwoX2S-
|
|
69
69
|
tico/passes/decompose_addmm.py,sha256=_yNX7wx1Y9HJI5ksUJI-UQLHpoNawbUbF8kcm2zGHw0,4221
|
70
70
|
tico/passes/decompose_batch_norm.py,sha256=d1V9UOkm_5BV0NGLyuQfz4I9NpO7I3ZrRugt7EXM-XM,7016
|
71
71
|
tico/passes/decompose_fake_quantize.py,sha256=7ZJyTIDj2iKgWa5q8mBSq6k0GX0vs_XyQdsIiWFJoTU,5175
|
72
|
-
tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=
|
73
|
-
tico/passes/decompose_group_norm.py,sha256=
|
72
|
+
tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=kOQaODKl_GCE19h-UZGmxnTcHtlvphI63dVAmMQL_Bk,13823
|
73
|
+
tico/passes/decompose_group_norm.py,sha256=xn1xnT-2e6BvelRAzX8O7wg9kBWURmPldkRvpfYFXHQ,9407
|
74
74
|
tico/passes/decompose_grouped_conv2d.py,sha256=KJhH6PX7l9k9T8KBV8JDAvaSfJuUnRo_jtvGF2aM-LA,8277
|
75
75
|
tico/passes/decompose_slice_scatter.py,sha256=ko9p8v-zY5rOx4aSpWomwSdSWb1lIF32gnU7ik5xgII,5604
|
76
76
|
tico/passes/extract_dtype_kwargs.py,sha256=hfGJ_GfZULbBmLif2AJkhPHVifhucxBiLoQI862Yejk,4303
|
@@ -140,7 +140,7 @@ tico/serialize/operators/op_logical_not.py,sha256=ugrVcRqR3IvUUaiRVW5cArCYJbzmkc
|
|
140
140
|
tico/serialize/operators/op_lt.py,sha256=_vA7dWpV9wVBxB7JL9pLQT9BsV91NGQBq_0auAtHK5Y,2080
|
141
141
|
tico/serialize/operators/op_max_pool2d_with_indices.py,sha256=ilQdirgSOjJR6dRIgAEF-sFmjPLQB3O2F_Fq5mbpYNA,5203
|
142
142
|
tico/serialize/operators/op_maximum.py,sha256=JjBr6gWEnuakLuk1_feotTHfIIm3s5YqWmqhUMpSPI0,1873
|
143
|
-
tico/serialize/operators/op_mean.py,sha256=
|
143
|
+
tico/serialize/operators/op_mean.py,sha256=rVQZOxCJkHFY4kQBAS1HVK0HkcqxgkSy6zvEDLX_WYQ,2267
|
144
144
|
tico/serialize/operators/op_minimum.py,sha256=fASjQVcTPCin02umQwFPdq2ss-Ve7S5A33J3QmmQ_wQ,1873
|
145
145
|
tico/serialize/operators/op_mm.py,sha256=fHggR9dmlwXw0DAyn__2JbG7e0q1Jhfmi5-2jDlpRDk,6730
|
146
146
|
tico/serialize/operators/op_mul.py,sha256=42Guc0MWBGBCZoj9-4LcLtTMtUPwsmDSVmvkR8tqLhM,3165
|
@@ -187,14 +187,14 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
|
187
187
|
tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
|
188
188
|
tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
|
189
189
|
tico/utils/utils.py,sha256=pybDU1LoNhjEplANig11lboX9yzYRkvFCSmyYth_2Do,10359
|
190
|
-
tico/utils/validate_args_kwargs.py,sha256=
|
190
|
+
tico/utils/validate_args_kwargs.py,sha256=z7JySDhyviI7G6Mo54AFdTiJIOE6Q6UrYnOFJj-5M24,24724
|
191
191
|
tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
192
192
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
193
193
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
194
194
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
195
|
-
tico-0.1.0.
|
196
|
-
tico-0.1.0.
|
197
|
-
tico-0.1.0.
|
198
|
-
tico-0.1.0.
|
199
|
-
tico-0.1.0.
|
200
|
-
tico-0.1.0.
|
195
|
+
tico-0.1.0.dev250521.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
196
|
+
tico-0.1.0.dev250521.dist-info/METADATA,sha256=3DnR9WqN27XltH1Abj77Oc2hEgKbT99QZ0w0oVdvAuc,8633
|
197
|
+
tico-0.1.0.dev250521.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
198
|
+
tico-0.1.0.dev250521.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
199
|
+
tico-0.1.0.dev250521.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
200
|
+
tico-0.1.0.dev250521.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|