tico 0.1.0.dev250722__py3-none-any.whl → 0.1.0.dev250724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/experimental/quantization/__init__.py +5 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +1 -6
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +1 -1
- tico/experimental/quantization/algorithm/pt2e/utils.py +0 -1
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +1 -1
- tico/experimental/quantization/evaluation/evaluate.py +1 -1
- tico/experimental/quantization/passes/fold_quant_ops.py +0 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +1 -1
- tico/experimental/quantization/passes/quantize_bias.py +0 -1
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +1 -1
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_mixed_type_args.py +2 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +0 -2
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/remove_redundant_expand.py +0 -5
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/serialize/circle_graph.py +1 -1
- tico/serialize/circle_serializer.py +234 -141
- tico/serialize/operators/op_any.py +0 -3
- tico/serialize/operators/op_clamp.py +2 -5
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_transpose_conv.py +0 -2
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +1 -1
- tico/utils/graph.py +1 -1
- tico/utils/padding.py +0 -2
- tico/utils/serialize.py +0 -3
- tico/utils/utils.py +1 -2
- tico/utils/validate_args_kwargs.py +1 -3
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/RECORD +49 -49
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -20,8 +20,16 @@ from packaging.version import Version
 from tico.config import CompileConfigV1, get_default_config
 from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
 
+__all__ = [
+    "CompileConfigV1",
+    "get_default_config",
+    "convert",
+    "convert_from_exported_program",
+    "convert_from_pt2",
+]
+
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250722"
+__version__ = "0.1.0.dev250724"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
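With __all__ defined, the package's public surface is explicit rather than implied by whatever happens to be imported. A minimal sketch of how the exported names appear to users (illustrative only; it just inspects the attributes added above):

    import tico

    # The five names listed in __all__ above are the intended entry points.
    print(tico.__all__)
    print(tico.__version__)  # "0.1.0.dev250724"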
tico/config/base.py
CHANGED
tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py
CHANGED
@@ -21,12 +21,7 @@ if TYPE_CHECKING:
     import torch.fx
     from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 import torch
-from torch.ao.quantization.observer import (
-    MinMaxObserver,
-    MovingAverageMinMaxObserver,
-    MovingAveragePerChannelMinMaxObserver,
-    PerChannelMinMaxObserver,
-)
+from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
 
tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable,
+from typing import Callable, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch.fx
tico/experimental/quantization/algorithm/pt2e/utils.py
CHANGED
@@ -19,7 +19,6 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
-from torch.utils import _pytree as pytree
 
 from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
     QuantizationConfig,
tico/experimental/quantization/evaluation/evaluate.py
CHANGED
@@ -166,7 +166,7 @@ def evaluate(
         )
     if not isinstance(backend, BACKEND):
        raise RuntimeError(
-
+            "Invalid backend. Please use tico.quantization.evaluate.BACKEND enum class"
         )
     # Make it a list for simpler logic.
     if input_data is not None:
tico/passes/cast_aten_where_arg_type.py
CHANGED
@@ -176,7 +176,7 @@ class CastATenWhereArgType(PassBase):
         node_dtype = extract_torch_dtype(node)
         assert (
             node_dtype == node_dtype_ori
-        ),
+        ), "Type casting doesn't change node's dtype."
 
         logger.debug(
             f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
tico/passes/cast_mixed_type_args.py
CHANGED
@@ -124,7 +124,7 @@ class CastMixedTypeArgs(PassBase):
            if rhs_val.dtype == type_to_promote:
                ori_type = lhs_val.dtype
                arg_to_promote = lhs
-            assert arg_to_promote
+            assert arg_to_promote is not None
 
            if isinstance(arg_to_promote, torch.fx.Node):
                with graph.inserting_after(arg_to_promote):
@@ -178,7 +178,7 @@ class CastMixedTypeArgs(PassBase):
            node_dtype = extract_torch_dtype(node)
            assert (
                node_dtype == node_dtype_ori
-            ),
+            ), "Type casting doesn't change node's dtype."
 
        graph.eliminate_dead_code()
        graph.lint()
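The switch from assert arg_to_promote to assert arg_to_promote is not None matters because the promoted argument is not always an fx node; the isinstance check right after it suggests it may be a plain Python scalar. A falsy but valid value would trip a bare truthiness assert. A tiny illustrative sketch (the value here is hypothetical, not taken from the pass):

    arg_to_promote = 0.0  # e.g. a scalar operand that should be promoted
    assert arg_to_promote is not None  # passes: only None is rejected
    # assert arg_to_promote            # would raise AssertionError: 0.0 is falsy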
tico/passes/const_prop_pass.py
CHANGED
@@ -301,7 +301,7 @@ class ConstPropPass(PassBase):
        graph.eliminate_dead_code()
        graph_module.recompile()
 
-        logger.debug(
+        logger.debug("Constant nodes are propagated")
        # Constant folding can be done with only one time run. Let's set `modified` to False.
        modified = False
        return PassResult(modified)
tico/passes/convert_conv1d_to_conv2d.py
CHANGED
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
 import torch
 from torch.export import ExportedProgram
 
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.errors import NotYetSupportedError
 from tico.utils.graph import create_node
tico/passes/decompose_addmm.py
CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch.export import ExportedProgram
 
 from tico.serialize.circle_mapping import extract_shape
-from tico.utils import logging
 from tico.utils.graph import add_placeholder, create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -59,8 +58,6 @@ class DecomposeAddmm(PassBase):
        super().__init__()
 
    def call(self, exported_program: ExportedProgram) -> PassResult:
-        logger = logging.getLogger(__name__)
-
        gm = exported_program.graph_module
        graph: torch.fx.Graph = gm.graph
        modified = False
tico/passes/decompose_batch_norm.py
CHANGED
@@ -96,9 +96,9 @@ class DecomposeBatchNorm(PassBase):
        eps = args.eps
 
        if not running_mean:
-            raise NotYetSupportedError(
+            raise NotYetSupportedError("running_mean=None is not supported yet")
        if not running_var:
-            raise NotYetSupportedError(
+            raise NotYetSupportedError("running_var=None is not supported yet")
 
        """
        Only support the cases generated from torch.nn.BatchNorm2d module,
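The fleshed-out messages make the unsupported case concrete: batch norm layers that do not track running statistics export with running_mean=None and running_var=None, and this pass rejects them. A minimal sketch of a module that would reach that path (assuming standard torch.nn.BatchNorm2d semantics; nothing tico-specific is called here):

    import torch

    # track_running_stats=False means no running_mean / running_var buffers,
    # so the exported batch_norm call would hit the error messages above.
    bn = torch.nn.BatchNorm2d(8, track_running_stats=False)
    print(bn.running_mean, bn.running_var)  # None None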
tico/passes/decompose_fake_quantize.py
CHANGED
@@ -19,10 +19,8 @@ if TYPE_CHECKING:
 import torch
 
 # To import torch.ops.quantized_decomposed related operator
-from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.export import ExportedProgram
 
-from tico.utils import logging
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -66,7 +64,6 @@ class DecomposeFakeQuantize(PassBase):
        super().__init__()
 
    def call(self, exported_program: ExportedProgram) -> PassResult:
-        logger = logging.getLogger(__name__)
        modified = False
 
        gm = exported_program.graph_module
tico/passes/decompose_fake_quantize_tensor_qparams.py
CHANGED
@@ -26,10 +26,8 @@ from torch._export.utils import (
 )
 
 # To import torch.ops.quantized_decomposed related operator
-from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.export import ExportedProgram
 
-from tico.utils import logging
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import (
tico/passes/decompose_group_norm.py
CHANGED
@@ -22,7 +22,6 @@ import torch
 from torch.export import ExportedProgram
 
 from tico.serialize.circle_mapping import extract_shape
-from tico.utils import logging
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -126,8 +125,6 @@ class DecomposeGroupNorm(PassBase):
        )
 
    def call(self, exported_program: ExportedProgram) -> PassResult:
-        logger = logging.getLogger(__name__)
-
        gm = exported_program.graph_module
        graph: torch.fx.Graph = gm.graph
        modified = False
tico/passes/legalize_predefined_layout_operators.py
CHANGED
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
 import torch
 from torch.export import ExportedProgram
 
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.errors import NotYetSupportedError
 from tico.utils.graph import create_node
@@ -206,7 +206,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
 
        args = ConvTranspose2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
-        padding = args.padding
        groups = args.groups
        dilation = args.dilation
 
@@ -288,13 +287,12 @@ class LegalizePreDefinedLayoutOperators(PassBase):
        input = args.input
        weight = args.weight
        bias = args.bias
-        eps = args.eps
 
        running_mean = args.running_mean
        running_var = args.running_var
        use_input_stats = args.use_input_stats
 
-        if not
+        if not use_input_stats:
            raise NotYetSupportedError("Only support use_input_stats is True.")
        if not isinstance(running_mean, NoneType):
            raise NotYetSupportedError("Only support running_mean=None")
@@ -350,10 +348,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
        # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
        args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
-        kernel_size = args.kernel_size
-        stride = args.stride
-        padding = args.padding
-        dilation = args.dilation
        ceil_mode = args.ceil_mode
        if ceil_mode:
            raise NotYetSupportedError("Only support non-ceil model.")
@@ -402,9 +396,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
        # avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
-        kernel_size = args.kernel_size
-        stride = args.stride
-        padding = args.padding
        ceil_mode = args.ceil_mode
        if ceil_mode:
            raise NotYetSupportedError("Only support non-ceil model.")
tico/passes/lower_to_resize_nearest_neighbor.py
CHANGED
@@ -67,7 +67,7 @@ class LowerToResizeNearestNeighbor(PassBase):
            return None
        # indices = [None, None, H index, W index]
        N, C, H, W = indices
-        if N
+        if N is not None or C is not None:
            return None
        if not isinstance(H, torch.fx.Node):
            return None
tico/passes/lower_to_slice.py
CHANGED
@@ -28,7 +28,7 @@ from torch._export.utils import (
 from torch.export import ExportedProgram
 
 from tico.passes import ops
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.graph import create_node, is_single_value_tensor
 from tico.utils.passes import PassBase, PassResult
tico/passes/merge_consecutive_cat.py
CHANGED
@@ -51,7 +51,7 @@ class MergeConsecutiveCat(PassBase):
            if not prev_cat.op == "call_function":
                continue
 
-            if
+            if prev_cat.target not in ops.aten.cat:
                continue
 
            prev_args = CatArgs(*prev_cat.args, **prev_cat.kwargs)  # type: ignore[arg-type]
tico/passes/remove_redundant_expand.py
CHANGED
@@ -12,11 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    import torch.fx
-import torch
 from torch.export import ExportedProgram
 
 from tico.passes import ops
tico/passes/remove_redundant_reshape.py
CHANGED
@@ -90,7 +90,7 @@ class RemoveRedundantReshapePattern1(PassBase):
            if len(permute.users) != 1:
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
-
+            permute_dims = permute_args.dims
            # (1xAxBxC) - `aten.permute` - (1xAxCxB)
            if permute_dims != [0, 1, 3, 2]:
                continue
@@ -172,7 +172,7 @@ class RemoveRedundantReshapePattern2(PassBase):
            if len(permute.users) != 1:
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
-
+            permute_dims = permute_args.dims
            # (1xAxBxC) - `aten.permute` - (Bx1xAxC)
            if permute_dims != [2, 0, 1, 3]:
                continue
@@ -262,7 +262,7 @@ class RemoveRedundantReshapePattern3(PassBase):
                continue
 
            # add
-            if
+            if add.target not in ops.aten.add:
                continue
            add_args = AddTensorArgs(*add.args, **add.kwargs)  # type: ignore[arg-type]
            reshape_2, reshape_3 = add_args.input, add_args.other
@@ -272,7 +272,7 @@ class RemoveRedundantReshapePattern3(PassBase):
            # reshape_2
            if not reshape_2.op == "call_function":
                continue
-            if
+            if reshape_2.target not in ops.aten.reshape:
                continue
            reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs)  # type: ignore[arg-type]
            reshape_2_input = reshape_2_args.input
@@ -280,7 +280,7 @@ class RemoveRedundantReshapePattern3(PassBase):
            # reshape_3
            if not reshape_3.op == "call_function":
                continue
-            if
+            if reshape_3.target not in ops.aten.reshape:
                continue
            reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs)  # type: ignore[arg-type]
            reshape_3_input = reshape_3_args.input
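For reference, the dims lists these patterns match are simple axis swaps of a rank-4 tensor, as the inline comments state. A quick standalone check of the two layouts (illustrative only):

    import torch

    x = torch.randn(1, 2, 3, 4)             # (1, A, B, C)
    print(x.permute(0, 1, 3, 2).shape)      # (1, A, C, B) -> Pattern1's [0, 1, 3, 2]
    print(x.permute(2, 0, 1, 3).shape)      # (B, 1, A, C) -> Pattern2's [2, 0, 1, 3]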
tico/passes/segment_index_select.py
CHANGED
@@ -29,7 +29,7 @@ from torch._export.utils import (
 from torch.export import ExportedProgram
 
 from tico.passes import ops
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.graph import add_placeholder, create_node, is_single_value_tensor
 from tico.utils.passes import PassBase, PassResult
tico/serialize/circle_graph.py
CHANGED
@@ -323,7 +323,7 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
        self, node: Union[torch.fx.Node, circle.Tensor.TensorT, ConstData]
    ) -> int:
        # return -1 if node is None. This is for generating CircleOutputExclude
-        if node
+        if node is None:
            return -1
 
        if hasattr(node, "name") and node.name in self.name_to_tid: