tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
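Most of this release is a package move: the quantization stack leaves `tico.experimental` and now lives under `tico.quantization`, as the renamed paths above show. A minimal migration sketch for downstream imports, assuming the entry points in `public_interface.py` keep their names across the move (the function names below are an assumption; only the module-path change is confirmed by this diff):

```python
# Before (0.1.0.dev250803):
# from tico.experimental.quantization.public_interface import prepare, convert

# After (0.1.0.dev251106) -- same file, new package prefix:
from tico.quantization.public_interface import prepare, convert  # function names assumed
```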
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -28,9 +28,9 @@ from tico.utils.validate_args_kwargs import MatmulArgs
 
 
 @register_node_visitor
-class MatmulDefaultVisitor(NodeVisitor):
+class MatmulVisitor(NodeVisitor):
     """
-    Convert matmul to
+    Convert matmul to Circle BatchMatMul
     """
 
     target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -38,131 +38,7 @@ class MatmulDefaultVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-
-    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_bmm_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-            )
-            option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-            option.adjointLhs, option.adjointRhs = False, False
-            option.asymmetricQuantizeInputs = False
-            operator.builtinOptions = option
-
-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
-        set_bmm_option(operator)
-
-        return operator
-
-    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_transpose_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.TransposeOptions
-            )
-            option = circle.TransposeOptions.TransposeOptionsT()
-            operator.builtinOptions = option
-
-        transpose_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
-        )
-        operator = create_builtin_operator(
-            self.graph, transpose_op_index, inputs, outputs
-        )
-        set_transpose_option(operator)
-        return operator
-
-    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_fc_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
-            )
-            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
-
-            option.fusedActivationFunction = (
-                circle.ActivationFunctionType.ActivationFunctionType.NONE
-            )
-            option.weightsFormat = (
-                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
-            )
-            option.keepNumDims = False
-            option.asymmetricQuantizeInputs = False
-            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
-
-            operator.builtinOptions = option
-
-        fc_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
-        set_fc_option(operator)
-        return operator
-
-    """
-    Define FullyConnnected with Tranpose operator.
-    Note that those sets of operators are equivalent.
-    (1) Matmul
-        matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')
-
-    (2) Transpose + FullyConneccted
-        transpose( rhs[K, W'] ) -> trs_output[W', K]
-        fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')
-    """
-
-    def define_fc_with_transpose(
-        self, node, inputs, outputs
-    ) -> circle.Operator.OperatorT:
-        lhs, rhs = inputs
-
-        # get transpose shape
-        rhs_tid: int = self.graph.get_tid_registered(rhs)
-        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
-        rhs_name: str = rhs.name
-        rhs_type: int = rhs_tensor.type
-        rhs_shape: List[int] = rhs_tensor.shape
-        assert len(rhs_shape) == 2, len(rhs_shape)
-        rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]
-
-        # create transpose output tensor
-        trs_output = self.graph.add_tensor_from_scratch(
-            prefix=f"{rhs_name}_transposed_output",
-            shape=rhs_shape_transpose,
-            shape_signature=None,
-            dtype=rhs_type,
-            source_node=node,
-        )
-        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
-        trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
-        self.graph.add_operator(trs_operator)
-
-        # define fc node
-        fc_input = lhs
-        fc_weight = trs_output
-        fc_shape = [fc_weight.shape[0]]
-        fc_bias = self.graph.add_const_tensor(
-            data=[0.0] * fc_shape[0], source_node=node
-        )
-
-        operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)
-
-        return operator
-
-    def define_node(
-        self, node: torch.fx.Node, prior_latency=True
-    ) -> circle.Operator.OperatorT:
-        """
-        NOTE: Possibility of accuracy-latency trade-off
-        From ONE compiler's perspective:
-        - BMM uses per-tensor quantization for both rhs and lhs.
-        - FC uses per-channel quantization for weight and per-tensor for input.
-        Thus, FC is better in terms of accuracy.
-        FC necessarily involves an additional transpose operation to be identical with mm.
-        If transposed operand is const, it can be optimized by constant folding.
-        Thus, convert FC only if tranpose can be folded.
-        TODO set prior_latency outside
-        """
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
         args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         other = args.other
@@ -170,9 +46,16 @@ class MatmulDefaultVisitor(NodeVisitor):
         inputs = [input, other]
         outputs = [node]
 
-
-
-
-
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
+        )
+        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
+        option.adjointLhs, option.adjointRhs = False, False
+        option.asymmetricQuantizeInputs = False
+        operator.builtinOptions = option
 
         return operator
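The removed comments above record why the old visitor could emit FullyConnected: for 2-D operands, `matmul(lhs[H, K], rhs[K, W'])` equals a FullyConnected whose weight is the transposed `rhs`. A small PyTorch sketch of that equivalence (illustrative only; in this release the FC conversion moves into the new `ConvertMatmulToLinear` pass and the visitor always emits BatchMatMul):

```python
import torch

# matmul( lhs[H, K], rhs[K, W'] ) -> out[H, W']
lhs = torch.randn(4, 8)    # [H, K]
rhs = torch.randn(8, 16)   # [K, W']
out_mm = torch.mm(lhs, rhs)

# transpose(rhs) -> [W', K]; FullyConnected(lhs, weight[W', K], bias=0) -> [H, W']
out_fc = torch.nn.functional.linear(lhs, rhs.t(), bias=torch.zeros(16))

assert torch.allclose(out_mm, out_fc, atol=1e-6)
```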
tico/serialize/operators/op_rmsnorm.py
ADDED
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import CircleRMSNormArgs
+
+
+@register_node_visitor
+class RMSNormVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.rms_norm.default,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = CircleRMSNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        weight = args.weight
+        eps = args.eps
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.RMS_NORM, self._op_codes
+        )
+
+        inputs = [input, weight]
+        outputs = [node]
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.RmsNormOptions
+        )
+        option = circle.RmsNormOptions.RmsNormOptionsT()
+        option.epsilon = eps
+
+        operator.builtinOptions = option
+
+        return operator
tico/utils/convert.py
CHANGED
@@ -20,26 +20,14 @@ import torch
 from torch.export import export, ExportedProgram
 
 from tico.config import CompileConfigBase, get_default_config
-from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
-from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
-    InsertQuantizeOnDtypeMismatch,
-)
-from tico.experimental.quantization.passes.propagate_qparam_backward import (
-    PropagateQParamBackward,
-)
-from tico.experimental.quantization.passes.propagate_qparam_forward import (
-    PropagateQParamForward,
-)
-from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
-from tico.experimental.quantization.passes.remove_weight_dequant_op import (
-    RemoveWeightDequantOp,
-)
 from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
 from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
 from tico.passes.convert_to_relu6 import ConvertToReLU6
 from tico.passes.decompose_addmm import DecomposeAddmm
@@ -72,6 +60,14 @@ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
 from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
 from tico.passes.restore_linear import RestoreLinear
 from tico.passes.segment_index_select import SegmentIndexSelectConst
+from tico.quantization.passes.fold_quant_ops import FoldQuantOps
+from tico.quantization.passes.insert_quantize_on_dtype_mismatch import (
+    InsertQuantizeOnDtypeMismatch,
+)
+from tico.quantization.passes.propagate_qparam_backward import PropagateQParamBackward
+from tico.quantization.passes.propagate_qparam_forward import PropagateQParamForward
+from tico.quantization.passes.quantize_bias import QuantizeBias
+from tico.quantization.passes.remove_weight_dequant_op import RemoveWeightDequantOp
 from tico.serialize.circle_serializer import build_circle
 from tico.serialize.operators.node_visitor import get_support_targets
 from tico.utils import logging
@@ -141,6 +137,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         or torch.__version__.startswith("2.7")
         or torch.__version__.startswith("2.8")
         or torch.__version__.startswith("2.9")
+        or torch.__version__.startswith("2.10")
     ):
         return run_decompositions(exported_program)
     else:
@@ -249,6 +246,14 @@ def convert_exported_module_to_circle(
             ConstPropPass(),
             SegmentIndexSelectConst(),
             LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+            ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
+            ConvertMatmulToLinear(
+                enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+                enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+                enable_single_batch_lhs_const_bmm=config.get(
+                    "convert_single_batch_lhs_const_bmm_to_fc"
+                ),
+            ),
             LowerToResizeNearestNeighbor(),
             LegalizePreDefinedLayoutOperators(),
             LowerPow2ToMul(),
@@ -287,7 +292,7 @@ def convert_exported_module_to_circle(
 
     check_unsupported_target(exported_program)
     check_training_ops(exported_program)
-    circle_program = build_circle(exported_program)
+    circle_program = build_circle(exported_program, config)
 
     return circle_program
 
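A hedged end-to-end sketch of how the new passes are reached during conversion. `tico.convert(model, example_inputs)` and `.save()` follow the project README; passing a config object through `convert` is an assumption here, and only the `config.get(...)` key names above are confirmed by this diff:

```python
import torch
import tico
from tico.config import get_default_config  # import shown in convert.py above


class TinyMatmul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Constant rhs operand: a candidate for ConvertMatmulToLinear
        # when "convert_rhs_const_mm_to_fc" is enabled in the config.
        self.rhs = torch.nn.Parameter(torch.randn(16, 8))

    def forward(self, x):
        return torch.mm(x, self.rhs)


model = TinyMatmul().eval()
config = get_default_config()  # defaults for the new keys are not shown in this diff

circle_model = tico.convert(model, (torch.randn(4, 16),), config=config)  # 'config' kwarg assumed
circle_model.save("tiny_matmul.circle")
```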
tico/utils/dtype.py
CHANGED
@@ -1,6 +1,8 @@
 import numpy as np
 import torch
 
+from circle_schema import circle
+
 NUMPY_TO_TORCH_DTYPE_DICT = {
     np.dtype("float32"): torch.float32,
     np.dtype("float64"): torch.float64,
@@ -15,6 +17,26 @@ NUMPY_TO_TORCH_DTYPE_DICT = {
     np.dtype("bool"): torch.bool,
 }
 
+CIRCLE_TO_TORCH_DTYPE_DICT = {
+    circle.TensorType.TensorType.FLOAT32: torch.float32,
+    circle.TensorType.TensorType.UINT8: torch.uint8,
+    circle.TensorType.TensorType.INT8: torch.int8,
+    circle.TensorType.TensorType.INT16: torch.int16,
+    circle.TensorType.TensorType.INT32: torch.int32,
+    circle.TensorType.TensorType.INT64: torch.int64,
+    circle.TensorType.TensorType.BOOL: torch.bool,
+}
+
 
 def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype:
     return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype]
+
+
+def circle_dtype_to_torch_dtype(circle_dtype: int) -> torch.dtype:
+    assert isinstance(circle_dtype, int)
+    if circle_dtype not in CIRCLE_TO_TORCH_DTYPE_DICT:
+        raise RuntimeError(f"Unsupported dtype {circle_dtype}")
+
+    torch_dtype = CIRCLE_TO_TORCH_DTYPE_DICT[circle_dtype]
+    assert torch_dtype is not None
+    return torch_dtype
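Usage sketch for the new helper; the mappings come directly from `CIRCLE_TO_TORCH_DTYPE_DICT` above:

```python
import torch
from circle_schema import circle

from tico.utils.dtype import circle_dtype_to_torch_dtype

assert circle_dtype_to_torch_dtype(circle.TensorType.TensorType.FLOAT32) == torch.float32
assert circle_dtype_to_torch_dtype(circle.TensorType.TensorType.INT64) == torch.int64
# Circle dtypes outside the dict raise RuntimeError("Unsupported dtype ...")
```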
tico/utils/register_custom_op.py
CHANGED
@@ -31,9 +31,11 @@ def CircleResizeNearestNeighbor():
         W_scale_factor = size[2] / W
         if H_scale_factor != W_scale_factor:
             raise RuntimeError("Scale factor of H and W should be same.")
-
-
+        permuted = torch.permute(input_, [0, 3, 1, 2])
+        resized = torch.nn.functional.interpolate(
+            permuted, scale_factor=H_scale_factor, mode="nearest"
         )
+        return torch.permute(resized, [0, 2, 3, 1])
 
     @register_fake("circle_custom::resize_nearest_neighbor")
     def _(input_: torch.Tensor, size: List[int]):
@@ -631,7 +633,7 @@ def CircleInstanceNorm():
         bias: Optional[torch.Tensor] = None,
         running_mean: Optional[torch.Tensor] = None,
         running_var: Optional[torch.Tensor] = None,
-        use_input_stats: bool =
+        use_input_stats: bool = True,
         momentum: float = 0.1,
         eps: float = 1e-05,
         cudnn_enabled: bool = False,
@@ -639,7 +641,7 @@ def CircleInstanceNorm():
         NHWC_to_NCHW = [0, 3, 1, 2]
         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
 
-        args = [NCHW_input, weight, bias, None, None,
+        args = [NCHW_input, weight, bias, None, None, True, momentum, eps, False]
         NCHW_output = torch.ops.aten.instance_norm.default(*args)
         NCHW_to_NHWC = [0, 2, 3, 1]
         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
@@ -703,6 +705,28 @@ def CircleQuantizeMX():
         return input_
 
 
+def CircleRMSNorm():
+    @custom_op("circle_custom::rms_norm", mutates_args=())
+    def rms_norm(
+        hidden_states: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + eps)
+        return weight * hidden_states.to(input_dtype)
+
+    @register_fake("circle_custom::rms_norm")
+    def _(
+        hidden_states: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        return hidden_states.new_empty(hidden_states.size())
+
+
 # Add custom ops to the torch namespace
 def RegisterOps():
     CircleResizeNearestNeighbor()
@@ -715,3 +739,4 @@ def RegisterOps():
     CircleAvgPool2D()
     CircleInstanceNorm()
     CircleQuantizeMX()
+    CircleRMSNorm()
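A short sketch exercising the new custom op after registration; the reference computation mirrors the eager `rms_norm` implementation added above:

```python
import torch

from tico.utils.register_custom_op import RegisterOps

RegisterOps()  # registers circle_custom::rms_norm along with the other custom ops

x = torch.randn(2, 4, 8)
w = torch.ones(8)
eps = 1e-5

out = torch.ops.circle_custom.rms_norm(x, w, eps)

# Same math as the eager implementation: x * rsqrt(mean(x^2) + eps), scaled by w
ref = w * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))
assert torch.allclose(out, ref, atol=1e-6)
```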
tico/utils/signature.py
ADDED
@@ -0,0 +1,247 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+import numpy as np
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_mapping import to_circle_shape
+from tico.utils.dtype import circle_dtype_to_torch_dtype
+from tico.utils.installed_packages import is_dynamic_cache_available
+
+
+def is_dynamic_cache_instance(value):
+    if is_dynamic_cache_available():
+        from transformers.cache_utils import DynamicCache
+
+        return isinstance(value, DynamicCache)
+    else:
+        return False
+
+
+def flatten_and_convert_kwargs(kwargs: dict) -> dict[str, torch.Tensor]:
+    result = {}  # type: ignore[var-annotated]
+    for k, v in kwargs.items():
+        if v is None:
+            continue
+        elif isinstance(v, (list, tuple)):
+            # 1. handle list
+            def unpack_recursive(name, value, store=None):
+                if store is None:
+                    store = {}
+
+                if isinstance(value, (tuple, list)):
+                    for i, v in enumerate(value):
+                        # recursive call. Append index to name and explore lower level
+                        unpack_recursive(f"{name}_{i}", v, store)
+                else:
+                    # base type (scalar etc.) directly stored
+                    store[name] = value
+
+                return store
+
+            unpack_recursive(k, v, result)
+        elif is_dynamic_cache_instance(v):
+            # 2. handle DynamicCache
+            for idx, cache_val in enumerate(v.key_cache):
+                result[f"{k}_key_cache_{idx}"] = cache_val
+
+            for idx, cache_val in enumerate(v.value_cache):
+                result[f"{k}_value_cache_{idx}"] = cache_val
+        else:
+            result[k] = v
+
+    # 3. Convert to tensors
+    for k, v in result.items():
+        result[k] = v if isinstance(v, torch.Tensor) else torch.tensor(v)
+
+    return result
+
+
+def flatten_and_convert_args(args: Sequence) -> tuple:
+    result = []  # type: ignore[var-annotated]
+    for item in args:
+        if item is None:
+            continue
+
+        # 1. recursion on list and tuple
+        if isinstance(item, (list, tuple)):
+            result.extend(flatten_and_convert_args(item))
+            continue
+
+        # 2. handle DynamicCache
+        if is_dynamic_cache_available():
+            from transformers.cache_utils import DynamicCache
+
+            if isinstance(item, DynamicCache):
+                # NOTE The tensor order is: key_in → key_out → value_in → value_out
+                #
+                # Refer to https://github.com/huggingface/transformers/blob/3457e8e73e4f5532cc69059682b1ba4484d7e7e8/src/transformers/cache_utils.py#L557
+                # ```
+                # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+                # ```
+                result.extend(item.key_cache)
+                result.extend(item.value_cache)
+                continue
+
+        # 3. Convert to tensors
+        result.append(item if isinstance(item, torch.Tensor) else torch.tensor(item))
+
+    return tuple(result)
+
+
+class ModelInputSpec:
+    @classmethod
+    def load(cls, circle_path):
+        def load(circle_path: str) -> bytes:
+            with open(circle_path, "rb") as f:
+                buf = bytes(f.read())
+            return buf
+
+        circle_binary = load(circle_path)
+        return cls(circle_binary)
+
+    def __init__(self, circle_binary):
+        model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+        assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
+
+        graph = model.Subgraphs(0)
+        tensors = [graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())]
+
+        self.names = [t.Name().decode("utf-8").split("::")[-1] for t in tensors]
+        self.shapes = [t.ShapeAsNumpy() for t in tensors]
+        self.shape_signatures = list(
+            map(
+                lambda x: None if (isinstance(x, int) and x == 0) else x,
+                (t.ShapeSignatureAsNumpy() for t in tensors),
+            )
+        )
+        self.types: list[torch.dtype] = [
+            circle_dtype_to_torch_dtype(t.Type()) for t in tensors
+        ]
+        self.name_to_idx = {name: idx for idx, name in enumerate(self.names)}
+
+    def bind(self, args, kwargs, check=True):
+        """Convert args and kwargs into an ordered list according to model input order"""
+        inputs = []
+        args = flatten_and_convert_args(args)
+        kwargs = flatten_and_convert_kwargs(kwargs)
+
+        arg_num = len(args) + len(kwargs)
+        m_input_num = len(self.names)
+        if arg_num != m_input_num:
+            raise ValueError(
+                f"Mismatch: number of model inputs and number of passed arguments are not the same: inputs({m_input_num}) != passed({arg_num}), input spec: {self.names}"
+            )
+
+        # 1. positional arguments
+        for i, val in enumerate(args):
+            name = self.names[i]
+            inputs.append(val)
+
+        # 2. keyword arguments
+        for idx in range(len(args), len(self.names)):
+            name = self.names[idx]
+            inputs.append(kwargs[name])
+
+        if check:
+            self.check_types(inputs)
+            self.check_shapes(inputs)
+
+        return inputs
+
+    def check_types(self, inputs):
+        """Check the types of input values"""
+        for i, (inp, ref_type) in enumerate(zip(inputs, self.types)):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor | int | float)
+            ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+            if isinstance(inp, torch.Tensor):
+                if inp.dtype != ref_type:
+                    raise TypeError(
+                        f"Input '{self.names[i]}' type {inp.dtype} != expected {ref_type}"
+                    )
+            else:
+                # Scalars (int, float)
+                if ref_type == torch.float32:
+                    if not isinstance(inp, (float)):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                elif ref_type == torch.int64:
+                    if not isinstance(inp, (int)):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                else:
+                    print(f"Unexpected ref_type: {ref_type}")
+
+    def check_shapes(self, inputs):
+        """Check the shapes of input values"""
+
+        def merge(shape, shape_sig):
+            """
+            Merge shape signature with shape
+            """
+            from copy import deepcopy
+
+            shape_merged = deepcopy(shape)
+            if shape_sig is not None:
+                for idx, ss in enumerate(shape_sig):
+                    if ss == -1:
+                        shape_merged[idx] = -1
+
+            return shape_merged
+
+        for i, (inp, ref_shape, ref_shape_sig) in enumerate(
+            zip(inputs, self.shapes, self.shape_signatures)
+        ):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor | int | float)
+            ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+            if isinstance(inp, torch.Tensor):  # Tensor
+                in_shape, in_shape_sig = to_circle_shape(inp.size())
+
+                if len(in_shape) != len(ref_shape):
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(in_shape)}!= expected {len(ref_shape)}"
+                    )
+
+                in_merged_shape = merge(in_shape, in_shape_sig)
+                ref_merged_shape = merge(ref_shape, ref_shape_sig)
+                for in_shp, ref_shp in zip(in_merged_shape, ref_merged_shape):
+                    if ref_shp == -1:
+                        continue
+                    if in_shp == -1:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has unknown dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig}) "
+                        )
+                    if in_shp != ref_shp:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has wrong dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig}) "
+                        )
+            elif isinstance(inp, (int, float)):  # Scalar
+                if len(ref_shape) > 0:
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(ref_shape)}"
+                    )
+            else:
+                print(f"Unexpected input type: {type(inp)}")