mct-nightly 2.2.0.20241211.531__py3-none-any.whl → 2.2.0.20241213.540__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/RECORD +21 -21
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +3 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +83 -14
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +407 -475
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -3
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +5 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +5 -6
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +5 -6
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +5 -6
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +9 -9
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +2 -2
- {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/top_level.txt +0 -0
@@ -12,23 +12,72 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
import copy
|
16
|
-
|
17
|
-
from enum import Enum
|
18
|
-
|
19
15
|
import pprint
|
20
16
|
|
17
|
+
from dataclasses import replace, dataclass, asdict, field
|
18
|
+
from enum import Enum
|
21
19
|
from typing import Dict, Any, Union, Tuple, List, Optional
|
22
|
-
|
23
20
|
from mct_quantizers import QuantizationMethod
|
24
21
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
25
|
-
|
26
22
|
from model_compression_toolkit.logger import Logger
|
27
23
|
from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
|
28
|
-
from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
|
29
24
|
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
|
30
|
-
|
31
|
-
|
25
|
+
_current_tp_model
|
26
|
+
|
27
|
+
class OperatorSetNames(Enum):
|
28
|
+
OPSET_CONV = "Conv"
|
29
|
+
OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
|
30
|
+
OPSET_CONV_TRANSPOSE = "ConvTraspose"
|
31
|
+
OPSET_FULLY_CONNECTED = "FullyConnected"
|
32
|
+
OPSET_CONCATENATE = "Concatenate"
|
33
|
+
OPSET_STACK = "Stack"
|
34
|
+
OPSET_UNSTACK = "Unstack"
|
35
|
+
OPSET_GATHER = "Gather"
|
36
|
+
OPSET_EXPAND = "Expend"
|
37
|
+
OPSET_BATCH_NORM = "BatchNorm"
|
38
|
+
OPSET_RELU = "ReLU"
|
39
|
+
OPSET_RELU6 = "ReLU6"
|
40
|
+
OPSET_LEAKY_RELU = "LEAKYReLU"
|
41
|
+
OPSET_HARD_TANH = "HardTanh"
|
42
|
+
OPSET_ADD = "Add"
|
43
|
+
OPSET_SUB = "Sub"
|
44
|
+
OPSET_MUL = "Mul"
|
45
|
+
OPSET_DIV = "Div"
|
46
|
+
OPSET_MIN_MAX = "MinMax"
|
47
|
+
OPSET_PRELU = "PReLU"
|
48
|
+
OPSET_SWISH = "Swish"
|
49
|
+
OPSET_SIGMOID = "Sigmoid"
|
50
|
+
OPSET_TANH = "Tanh"
|
51
|
+
OPSET_GELU = "Gelu"
|
52
|
+
OPSET_HARDSIGMOID = "HardSigmoid"
|
53
|
+
OPSET_HARDSWISH = "HardSwish"
|
54
|
+
OPSET_FLATTEN = "Flatten"
|
55
|
+
OPSET_GET_ITEM = "GetItem"
|
56
|
+
OPSET_RESHAPE = "Reshape"
|
57
|
+
OPSET_UNSQUEEZE = "Unsqueeze"
|
58
|
+
OPSET_SQUEEZE = "Squeeze"
|
59
|
+
OPSET_PERMUTE = "Permute"
|
60
|
+
OPSET_TRANSPOSE = "Transpose"
|
61
|
+
OPSET_DROPOUT = "Dropout"
|
62
|
+
OPSET_SPLIT = "Split"
|
63
|
+
OPSET_CHUNK = "Chunk"
|
64
|
+
OPSET_UNBIND = "Unbind"
|
65
|
+
OPSET_MAXPOOL = "MaxPool"
|
66
|
+
OPSET_SIZE = "Size"
|
67
|
+
OPSET_SHAPE = "Shape"
|
68
|
+
OPSET_EQUAL = "Equal"
|
69
|
+
OPSET_ARGMAX = "ArgMax"
|
70
|
+
OPSET_TOPK = "TopK"
|
71
|
+
OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS = "FakeQuantWithMinMaxVars"
|
72
|
+
OPSET_COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
|
73
|
+
OPSET_CROPPING2D = "Cropping2D"
|
74
|
+
OPSET_ZERO_PADDING2d = "ZeroPadding2D"
|
75
|
+
OPSET_CAST = "Cast"
|
76
|
+
OPSET_STRIDED_SLICE = "StridedSlice"
|
77
|
+
|
78
|
+
@classmethod
|
79
|
+
def get_values(cls):
|
80
|
+
return [v.value for v in cls]
|
32
81
|
|
33
82
|
|
34
83
|
class Signedness(Enum):
|
@@ -44,451 +93,420 @@ class Signedness(Enum):
|
|
44
93
|
UNSIGNED = 2
|
45
94
|
|
46
95
|
|
96
|
+
@dataclass(frozen=True)
|
47
97
|
class AttributeQuantizationConfig:
|
48
98
|
"""
|
49
|
-
|
50
|
-
"""
|
51
|
-
def __init__(self,
|
52
|
-
weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO,
|
53
|
-
weights_n_bits: int = FLOAT_BITWIDTH,
|
54
|
-
weights_per_channel_threshold: bool = False,
|
55
|
-
enable_weights_quantization: bool = False,
|
56
|
-
lut_values_bitwidth: Union[int, None] = None, # If None - set 8 in hptq, o.w use it
|
57
|
-
):
|
58
|
-
"""
|
59
|
-
Initializes an attribute quantization config.
|
60
|
-
|
61
|
-
Args:
|
62
|
-
weights_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for weights quantization.
|
63
|
-
weights_n_bits (int): Number of bits to quantize the coefficients.
|
64
|
-
weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
|
65
|
-
enable_weights_quantization (bool): Whether to quantize the model weights or not.
|
66
|
-
lut_values_bitwidth (int): Number of bits to use when quantizing in look-up-table.
|
99
|
+
Holds the quantization configuration of a weight attribute of a layer.
|
67
100
|
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
101
|
+
Attributes:
|
102
|
+
weights_quantization_method (QuantizationMethod): The method to use from QuantizationMethod for weights quantization.
|
103
|
+
weights_n_bits (int): Number of bits to quantize the coefficients.
|
104
|
+
weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor.
|
105
|
+
enable_weights_quantization (bool): Indicates whether to quantize the model weights or not.
|
106
|
+
lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table.
|
107
|
+
If None, defaults to 8 in hptq; otherwise, it uses the provided value.
|
108
|
+
"""
|
109
|
+
weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO
|
110
|
+
weights_n_bits: int = FLOAT_BITWIDTH
|
111
|
+
weights_per_channel_threshold: bool = False
|
112
|
+
enable_weights_quantization: bool = False
|
113
|
+
lut_values_bitwidth: Optional[int] = None
|
75
114
|
|
76
|
-
def
|
115
|
+
def __post_init__(self):
|
77
116
|
"""
|
78
|
-
|
117
|
+
Post-initialization processing for input validation.
|
79
118
|
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
Returns:
|
84
|
-
Edited quantization configuration.
|
119
|
+
Raises:
|
120
|
+
Logger critical if attributes are of incorrect type or have invalid values.
|
85
121
|
"""
|
122
|
+
if not isinstance(self.weights_n_bits, int) or self.weights_n_bits < 1:
|
123
|
+
Logger.critical("weights_n_bits must be a positive integer.") # pragma: no cover
|
124
|
+
if not isinstance(self.enable_weights_quantization, bool):
|
125
|
+
Logger.critical("enable_weights_quantization must be a boolean.") # pragma: no cover
|
126
|
+
if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int):
|
127
|
+
Logger.critical("lut_values_bitwidth must be an integer or None.") # pragma: no cover
|
86
128
|
|
87
|
-
|
88
|
-
|
89
|
-
def __eq__(self, other):
|
129
|
+
def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig':
|
90
130
|
"""
|
91
|
-
|
131
|
+
Clone the current AttributeQuantizationConfig and edit some of its attributes.
|
92
132
|
|
93
133
|
Args:
|
94
|
-
|
134
|
+
**kwargs: Keyword arguments representing the attributes to edit in the cloned instance.
|
95
135
|
|
96
136
|
Returns:
|
97
|
-
|
98
|
-
Whether this configuration is equal to another object or not.
|
137
|
+
AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes.
|
99
138
|
"""
|
100
|
-
|
101
|
-
return False # pragma: no cover
|
102
|
-
return self.weights_quantization_method == other.weights_quantization_method and \
|
103
|
-
self.weights_n_bits == other.weights_n_bits and \
|
104
|
-
self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
|
105
|
-
self.enable_weights_quantization == other.enable_weights_quantization and \
|
106
|
-
self.lut_values_bitwidth == other.lut_values_bitwidth
|
139
|
+
return replace(self, **kwargs)
|
107
140
|
|
108
141
|
|
142
|
+
@dataclass(frozen=True)
|
109
143
|
class OpQuantizationConfig:
|
110
144
|
"""
|
111
145
|
OpQuantizationConfig is a class to configure the quantization parameters of an operator.
|
112
|
-
"""
|
113
146
|
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
):
|
127
|
-
"""
|
147
|
+
Args:
|
148
|
+
default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation.
|
149
|
+
attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
|
150
|
+
activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
|
151
|
+
activation_n_bits (int): Number of bits to quantize the activations.
|
152
|
+
supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input.
|
153
|
+
enable_activation_quantization (bool): Whether to quantize the model activations or not.
|
154
|
+
quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
|
155
|
+
fixed_scale (float): Scale to use for an operator's quantization parameters.
|
156
|
+
fixed_zero_point (int): Zero-point to use for an operator's quantization parameters.
|
157
|
+
simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
|
158
|
+
signedness (Signedness): Set activation quantization signedness.
|
128
159
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
if isinstance(supported_input_activation_n_bits,
|
150
|
-
self
|
151
|
-
elif isinstance(supported_input_activation_n_bits,
|
152
|
-
|
153
|
-
|
154
|
-
Logger.critical(f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(supported_input_activation_n_bits)}") # pragma: no cover
|
155
|
-
self.enable_activation_quantization = enable_activation_quantization
|
156
|
-
self.quantization_preserving = quantization_preserving
|
157
|
-
self.fixed_scale = fixed_scale
|
158
|
-
self.fixed_zero_point = fixed_zero_point
|
159
|
-
self.signedness = signedness
|
160
|
-
self.simd_size = simd_size
|
160
|
+
"""
|
161
|
+
default_weight_attr_config: AttributeQuantizationConfig
|
162
|
+
attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig]
|
163
|
+
activation_quantization_method: QuantizationMethod
|
164
|
+
activation_n_bits: int
|
165
|
+
supported_input_activation_n_bits: Union[int, Tuple[int]]
|
166
|
+
enable_activation_quantization: bool
|
167
|
+
quantization_preserving: bool
|
168
|
+
fixed_scale: float
|
169
|
+
fixed_zero_point: int
|
170
|
+
simd_size: int
|
171
|
+
signedness: Signedness
|
172
|
+
|
173
|
+
def __post_init__(self):
|
174
|
+
"""
|
175
|
+
Post-initialization processing for input validation.
|
176
|
+
|
177
|
+
Raises:
|
178
|
+
Logger critical if supported_input_activation_n_bits is not an int or a tuple of ints.
|
179
|
+
"""
|
180
|
+
if isinstance(self.supported_input_activation_n_bits, int):
|
181
|
+
object.__setattr__(self, 'supported_input_activation_n_bits', (self.supported_input_activation_n_bits,))
|
182
|
+
elif not isinstance(self.supported_input_activation_n_bits, tuple):
|
183
|
+
Logger.critical(
|
184
|
+
f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(self.supported_input_activation_n_bits)}") # pragma: no cover
|
161
185
|
|
162
|
-
def get_info(self):
|
186
|
+
def get_info(self) -> Dict[str, Any]:
|
163
187
|
"""
|
188
|
+
Get information about the quantization configuration.
|
164
189
|
|
165
|
-
Returns:
|
166
|
-
|
190
|
+
Returns:
|
191
|
+
dict: Information about the quantization configuration as a dictionary.
|
167
192
|
"""
|
168
|
-
return self
|
193
|
+
return asdict(self) # pragma: no cover
|
169
194
|
|
170
|
-
def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs):
|
195
|
+
def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs) -> 'OpQuantizationConfig':
|
171
196
|
"""
|
172
197
|
Clone the quantization config and edit some of its attributes.
|
198
|
+
|
173
199
|
Args:
|
174
|
-
attr_to_edit: A mapping between
|
175
|
-
|
200
|
+
attr_to_edit (Dict[str, Dict[str, Any]]): A mapping between attribute names to edit and their parameters that
|
201
|
+
should be edited to a new value.
|
176
202
|
**kwargs: Keyword arguments to edit the configuration to clone.
|
177
203
|
|
178
204
|
Returns:
|
179
|
-
Edited quantization configuration.
|
205
|
+
OpQuantizationConfig: Edited quantization configuration.
|
180
206
|
"""
|
181
207
|
|
182
|
-
|
208
|
+
# Clone and update top-level attributes
|
209
|
+
updated_config = replace(self, **kwargs)
|
183
210
|
|
184
|
-
#
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
211
|
+
# Clone and update nested immutable dataclasses in `attr_weights_configs_mapping`
|
212
|
+
updated_attr_mapping = {
|
213
|
+
attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
|
214
|
+
if attr_name in attr_to_edit else attr_cfg)
|
215
|
+
for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items()
|
216
|
+
}
|
189
217
|
|
190
|
-
|
218
|
+
# Return a new instance with the updated attribute mapping
|
219
|
+
return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping)
|
191
220
|
|
192
|
-
return qc
|
193
221
|
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
other: Object to compare.
|
222
|
+
@dataclass(frozen=True)
|
223
|
+
class QuantizationConfigOptions:
|
224
|
+
"""
|
225
|
+
QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator.
|
199
226
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
self.attr_weights_configs_mapping == other.attr_weights_configs_mapping and \
|
207
|
-
self.activation_quantization_method == other.activation_quantization_method and \
|
208
|
-
self.activation_n_bits == other.activation_n_bits and \
|
209
|
-
self.supported_input_activation_n_bits == other.supported_input_activation_n_bits and \
|
210
|
-
self.enable_activation_quantization == other.enable_activation_quantization and \
|
211
|
-
self.signedness == other.signedness and \
|
212
|
-
self.simd_size == other.simd_size
|
227
|
+
Attributes:
|
228
|
+
quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather.
|
229
|
+
base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
|
230
|
+
"""
|
231
|
+
quantization_config_list: List[OpQuantizationConfig]
|
232
|
+
base_config: Optional[OpQuantizationConfig] = None
|
213
233
|
|
214
|
-
|
215
|
-
def max_input_activation_n_bits(self) -> int:
|
234
|
+
def __post_init__(self):
|
216
235
|
"""
|
217
|
-
|
218
|
-
|
219
|
-
Returns: Maximum supported input bit-width.
|
236
|
+
Post-initialization processing for input validation.
|
220
237
|
|
238
|
+
Raises:
|
239
|
+
Logger critical if quantization_config_list is not a list, contains invalid elements, or if base_config is not set correctly.
|
221
240
|
"""
|
222
|
-
|
241
|
+
# Validate `quantization_config_list`
|
242
|
+
if not isinstance(self.quantization_config_list, list):
|
243
|
+
Logger.critical(
|
244
|
+
f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") # pragma: no cover
|
245
|
+
for cfg in self.quantization_config_list:
|
246
|
+
if not isinstance(cfg, OpQuantizationConfig):
|
247
|
+
Logger.critical(
|
248
|
+
f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") # pragma: no cover
|
223
249
|
|
250
|
+
# Handle base_config
|
251
|
+
if len(self.quantization_config_list) > 1:
|
252
|
+
if self.base_config is None:
|
253
|
+
Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") # pragma: no cover
|
254
|
+
if not any(self.base_config == cfg for cfg in self.quantization_config_list):
|
255
|
+
Logger.critical(f"'base_config' must be included in the quantization config options list.") # pragma: no cover
|
256
|
+
elif len(self.quantization_config_list) == 1:
|
257
|
+
if self.base_config is None:
|
258
|
+
object.__setattr__(self, 'base_config', self.quantization_config_list[0])
|
259
|
+
elif self.base_config != self.quantization_config_list[0]:
|
260
|
+
Logger.critical(
|
261
|
+
"'base_config' should be the same as the sole item in 'quantization_config_list'.") # pragma: no cover
|
224
262
|
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
Wrap a set of quantization configurations to consider during the quantization
|
229
|
-
of an operator.
|
263
|
+
elif len(self.quantization_config_list) == 0:
|
264
|
+
Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") # pragma: no cover
|
230
265
|
|
231
|
-
|
232
|
-
def __init__(self,
|
233
|
-
quantization_config_list: List[OpQuantizationConfig],
|
234
|
-
base_config: OpQuantizationConfig = None):
|
266
|
+
def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
|
235
267
|
"""
|
268
|
+
Clone the quantization configuration options and edit attributes in each configuration.
|
236
269
|
|
237
270
|
Args:
|
238
|
-
|
239
|
-
base_config (OpQuantizationConfig): Fallback OpQuantizationConfig to use when optimizing the model in a non mixed-precision manner.
|
240
|
-
"""
|
241
|
-
|
242
|
-
assert isinstance(quantization_config_list,
|
243
|
-
list), f"'QuantizationConfigOptions' options list must be a list, but received: {type(quantization_config_list)}."
|
244
|
-
for cfg in quantization_config_list:
|
245
|
-
assert isinstance(cfg, OpQuantizationConfig),\
|
246
|
-
f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}."
|
247
|
-
self.quantization_config_list = quantization_config_list
|
248
|
-
if len(quantization_config_list) > 1:
|
249
|
-
assert base_config is not None, \
|
250
|
-
f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization."
|
251
|
-
assert any([base_config is cfg for cfg in quantization_config_list]), \
|
252
|
-
f"'base_config' must be included in the quantization config options list."
|
253
|
-
# Enforce base_config to be a reference to an instance in quantization_config_list.
|
254
|
-
self.base_config = base_config
|
255
|
-
elif len(quantization_config_list) == 1:
|
256
|
-
assert base_config is None or base_config == quantization_config_list[0], "'base_config' should be included in 'quantization_config_list'"
|
257
|
-
# Set base_config to be a reference to the first instance in quantization_config_list.
|
258
|
-
self.base_config = quantization_config_list[0]
|
259
|
-
else:
|
260
|
-
raise AssertionError("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.")
|
261
|
-
|
262
|
-
def __eq__(self, other):
|
263
|
-
"""
|
264
|
-
Is this QCOptions equal to another object.
|
265
|
-
Args:
|
266
|
-
other: Object to compare.
|
271
|
+
**kwargs: Keyword arguments to edit in each configuration.
|
267
272
|
|
268
273
|
Returns:
|
269
|
-
|
274
|
+
A new instance of QuantizationConfigOptions with updated configurations.
|
270
275
|
"""
|
276
|
+
updated_base_config = replace(self.base_config, **kwargs)
|
277
|
+
updated_configs_list = [
|
278
|
+
replace(cfg, **kwargs) for cfg in self.quantization_config_list
|
279
|
+
]
|
280
|
+
return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs_list)
|
271
281
|
|
272
|
-
|
273
|
-
return False
|
274
|
-
if len(self.quantization_config_list) != len(other.quantization_config_list):
|
275
|
-
return False
|
276
|
-
for qc, other_qc in zip(self.quantization_config_list, other.quantization_config_list):
|
277
|
-
if qc != other_qc:
|
278
|
-
return False
|
279
|
-
return True
|
280
|
-
|
281
|
-
def clone_and_edit(self, **kwargs):
|
282
|
-
qc_options = copy.deepcopy(self)
|
283
|
-
for qc in qc_options.quantization_config_list:
|
284
|
-
self.__edit_quantization_configuration(qc, kwargs)
|
285
|
-
return qc_options
|
286
|
-
|
287
|
-
def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs):
|
282
|
+
def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions':
|
288
283
|
"""
|
289
284
|
Clones the quantization configurations and edits some of their attributes' parameters.
|
290
285
|
|
291
286
|
Args:
|
292
|
-
attrs:
|
293
|
-
of all attributes in the operation attributes config mapping.
|
287
|
+
attrs (List[str]): Attributes names to clone and edit their configurations. If None, updates all attributes.
|
294
288
|
**kwargs: Keyword arguments to edit in the attributes configuration.
|
295
289
|
|
296
290
|
Returns:
|
297
|
-
QuantizationConfigOptions with edited attributes configurations.
|
298
|
-
|
291
|
+
QuantizationConfigOptions: A new instance of QuantizationConfigOptions with edited attributes configurations.
|
299
292
|
"""
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
for qc in qc_options.quantization_config_list:
|
293
|
+
updated_base_config = self.base_config
|
294
|
+
updated_configs = []
|
295
|
+
for qc in self.quantization_config_list:
|
304
296
|
if attrs is None:
|
305
297
|
attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
|
306
298
|
else:
|
307
|
-
if not isinstance(attrs, List): # pragma: no cover
|
308
|
-
Logger.critical(f"Expected a list of attributes but received {type(attrs)}.")
|
309
299
|
attrs_to_update = attrs
|
310
|
-
|
300
|
+
# Ensure all attributes exist in the config
|
311
301
|
for attr in attrs_to_update:
|
312
|
-
if
|
313
|
-
Logger.critical(f
|
314
|
-
|
315
|
-
|
316
|
-
|
302
|
+
if attr not in qc.attr_weights_configs_mapping:
|
303
|
+
Logger.critical(f"{attr} does not exist in {qc}.")
|
304
|
+
updated_attr_mapping = {
|
305
|
+
attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
|
306
|
+
for attr in attrs_to_update
|
307
|
+
}
|
308
|
+
if qc == updated_base_config:
|
309
|
+
updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping)
|
310
|
+
updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
|
311
|
+
return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs)
|
317
312
|
|
318
|
-
def clone_and_map_weights_attr_keys(self, layer_attrs_mapping:
|
313
|
+
def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
|
319
314
|
"""
|
320
|
-
|
321
|
-
based on the given attributes names mapping.
|
315
|
+
Clones the quantization configurations and updates keys in attribute config mappings.
|
322
316
|
|
323
317
|
Args:
|
324
|
-
layer_attrs_mapping: A mapping between
|
318
|
+
layer_attrs_mapping (Optional[Dict[str, str]]): A mapping between attribute names.
|
325
319
|
|
326
320
|
Returns:
|
327
|
-
QuantizationConfigOptions with
|
328
|
-
|
321
|
+
QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys.
|
329
322
|
"""
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
# Check if the base_config is already included in the quantization configuration list
|
335
|
-
# If not, add base_config to the list of configurations to update
|
336
|
-
cfgs_to_update = [cfg for cfg in qc_options.quantization_config_list]
|
337
|
-
if not any(qc_options.base_config is cfg for cfg in cfgs_to_update):
|
338
|
-
# TODO: add test for this case
|
339
|
-
cfgs_to_update.append(qc_options.base_config)
|
340
|
-
|
341
|
-
for qc in cfgs_to_update:
|
323
|
+
updated_configs = []
|
324
|
+
new_base_config = self.base_config
|
325
|
+
for qc in self.quantization_config_list:
|
342
326
|
if layer_attrs_mapping is None:
|
343
|
-
qc.attr_weights_configs_mapping = {}
|
344
|
-
else:
|
345
327
|
new_attr_mapping = {}
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
return qc_options
|
328
|
+
else:
|
329
|
+
new_attr_mapping = {
|
330
|
+
layer_attrs_mapping.get(attr, attr): cfg
|
331
|
+
for attr, cfg in qc.attr_weights_configs_mapping.items()
|
332
|
+
}
|
333
|
+
if qc == self.base_config:
|
334
|
+
new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping)
|
335
|
+
updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping))
|
336
|
+
return replace(self, base_config=new_base_config, quantization_config_list=updated_configs)
|
356
337
|
|
357
|
-
def
|
358
|
-
|
359
|
-
|
360
|
-
k), (f'Editing is only possible for existing attributes in the configuration; '
|
361
|
-
f'{k} is not an attribute of {qc}.')
|
362
|
-
setattr(qc, k, v)
|
338
|
+
def get_info(self) -> Dict[str, Any]:
|
339
|
+
"""
|
340
|
+
Get detailed information about each quantization configuration option.
|
363
341
|
|
364
|
-
|
342
|
+
Returns:
|
343
|
+
dict: Information about the quantization configuration options as a dictionary.
|
344
|
+
"""
|
365
345
|
return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)}
|
366
346
|
|
367
347
|
|
348
|
+
@dataclass(frozen=True)
|
368
349
|
class TargetPlatformModelComponent:
|
369
350
|
"""
|
370
|
-
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.)
|
351
|
+
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
|
371
352
|
"""
|
372
|
-
def __init__(self, name: str):
|
373
|
-
"""
|
374
353
|
|
375
|
-
|
376
|
-
|
354
|
+
def __post_init__(self):
|
355
|
+
"""
|
356
|
+
Post-initialization to register the component with the current TargetPlatformModel.
|
377
357
|
"""
|
378
|
-
self.name = name
|
379
358
|
_current_tp_model.get().append_component(self)
|
380
359
|
|
381
360
|
def get_info(self) -> Dict[str, Any]:
|
382
361
|
"""
|
362
|
+
Get information about the component to display.
|
383
363
|
|
384
|
-
Returns:
|
385
|
-
|
386
|
-
|
364
|
+
Returns:
|
365
|
+
Dict[str, Any]: Returns an empty dictionary. The actual component should override
|
366
|
+
this method to provide relevant information.
|
387
367
|
"""
|
388
368
|
return {}
|
389
369
|
|
390
370
|
|
371
|
+
@dataclass(frozen=True)
|
391
372
|
class OperatorsSetBase(TargetPlatformModelComponent):
|
392
373
|
"""
|
393
|
-
Base class to represent a set of
|
374
|
+
Base class for target platform model components that represent operator sets.
|
375
|
+
Inherits from TargetPlatformModelComponent.
|
394
376
|
"""
|
395
|
-
def
|
377
|
+
def __post_init__(self):
|
396
378
|
"""
|
397
|
-
|
398
|
-
|
399
|
-
name: Name of OperatorsSet.
|
379
|
+
Post-initialization to ensure the component is registered with the TargetPlatformModel.
|
380
|
+
Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel.
|
400
381
|
"""
|
401
|
-
super().
|
382
|
+
super().__post_init__()
|
402
383
|
|
403
384
|
|
385
|
+
@dataclass(frozen=True)
|
404
386
|
class OperatorsSet(OperatorsSetBase):
|
405
|
-
|
406
|
-
|
407
|
-
qc_options: QuantizationConfigOptions = None):
|
408
|
-
"""
|
409
|
-
Set of operators that are represented by a unique label.
|
387
|
+
"""
|
388
|
+
Set of operators that are represented by a unique label.
|
410
389
|
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
390
|
+
Attributes:
|
391
|
+
name (str): The set's label (must be unique within a TargetPlatformModel).
|
392
|
+
qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
|
393
|
+
If None, it represents a fusing set.
|
394
|
+
is_default (bool): Indicates whether this set is the default quantization configuration
|
395
|
+
for the TargetPlatformModel or a fusing set.
|
396
|
+
"""
|
397
|
+
name: str
|
398
|
+
qc_options: QuantizationConfigOptions = None
|
415
399
|
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
|
400
|
+
def __post_init__(self):
|
401
|
+
"""
|
402
|
+
Post-initialization processing to mark the operator set as default if applicable.
|
420
403
|
|
404
|
+
Calls the parent class's __post_init__ method and sets `is_default` to True
|
405
|
+
if this set corresponds to the default quantization configuration for the
|
406
|
+
TargetPlatformModel or if it is a fusing set.
|
421
407
|
|
422
|
-
def get_info(self) -> Dict[str,Any]:
|
423
408
|
"""
|
409
|
+
super().__post_init__()
|
410
|
+
is_fusing_set = self.qc_options is None
|
411
|
+
is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
|
412
|
+
object.__setattr__(self, 'is_default', is_default)
|
424
413
|
|
425
|
-
|
414
|
+
def get_info(self) -> Dict[str, Any]:
|
415
|
+
"""
|
416
|
+
Get information about the set as a dictionary.
|
426
417
|
|
418
|
+
Returns:
|
419
|
+
Dict[str, Any]: A dictionary containing the set name and
|
420
|
+
whether it is the default quantization configuration.
|
427
421
|
"""
|
428
422
|
return {"name": self.name,
|
429
423
|
"is_default_qc": self.is_default}
|
430
424
|
|
431
425
|
|
426
|
+
@dataclass(frozen=True)
|
432
427
|
class OperatorSetConcat(OperatorsSetBase):
|
433
428
|
"""
|
434
429
|
Concatenate a list of operator sets to treat them similarly in different places (like fusing).
|
430
|
+
|
431
|
+
Attributes:
|
432
|
+
op_set_list (List[OperatorsSet]): List of operator sets to group.
|
433
|
+
qc_options (None): Configuration options for the set, always None for concatenated sets.
|
434
|
+
name (str): Concatenated name generated from the names of the operator sets in the list.
|
435
435
|
"""
|
436
|
-
|
437
|
-
|
438
|
-
|
436
|
+
op_set_list: List[OperatorsSet] = field(default_factory=list)
|
437
|
+
qc_options: None = field(default=None, init=False)
|
438
|
+
name: str = None
|
439
439
|
|
440
|
-
|
441
|
-
*opsets (OperatorsSet): List of operator sets to group.
|
440
|
+
def __post_init__(self):
|
442
441
|
"""
|
443
|
-
|
444
|
-
super().__init__(name=name)
|
445
|
-
self.op_set_list = opsets
|
446
|
-
self.qc_options = None # Concat have no qc options
|
442
|
+
Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
|
447
443
|
|
448
|
-
|
444
|
+
Calls the parent class's __post_init__ method and creates a concatenated name
|
445
|
+
by joining the names of all operator sets in `op_set_list`.
|
449
446
|
"""
|
447
|
+
super().__post_init__()
|
448
|
+
# Generate the concatenated name from the operator sets
|
449
|
+
concatenated_name = "_".join([op.name for op in self.op_set_list])
|
450
|
+
# Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
|
451
|
+
object.__setattr__(self, "name", concatenated_name)
|
450
452
|
|
451
|
-
|
453
|
+
def get_info(self) -> Dict[str, Any]:
|
454
|
+
"""
|
455
|
+
Get information about the concatenated set as a dictionary.
|
452
456
|
|
457
|
+
Returns:
|
458
|
+
Dict[str, Any]: A dictionary containing the concatenated name and
|
459
|
+
the list of names of the operator sets in `op_set_list`.
|
453
460
|
"""
|
454
461
|
return {"name": self.name,
|
455
462
|
OPS_SET_LIST: [s.name for s in self.op_set_list]}
|
456
463
|
|
457
464
|
|
465
|
+
@dataclass(frozen=True)
|
458
466
|
class Fusing(TargetPlatformModelComponent):
|
459
467
|
"""
|
460
|
-
|
461
|
-
|
468
|
+
Fusing defines a list of operators that should be combined and treated as a single operator,
|
469
|
+
hence no quantization is applied between them.
|
470
|
+
|
471
|
+
Attributes:
|
472
|
+
operator_groups_list (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups,
|
473
|
+
each being either an OperatorSetConcat or an OperatorsSet.
|
474
|
+
name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
|
462
475
|
"""
|
476
|
+
operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]]
|
477
|
+
name: str = None
|
463
478
|
|
464
|
-
def
|
465
|
-
operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
|
466
|
-
name: str = None):
|
467
|
-
"""
|
468
|
-
Args:
|
469
|
-
operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
|
470
|
-
name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
|
479
|
+
def __post_init__(self):
|
471
480
|
"""
|
472
|
-
|
473
|
-
list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
|
474
|
-
assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
|
481
|
+
Post-initialization processing for input validation and name generation.
|
475
482
|
|
476
|
-
|
477
|
-
|
478
|
-
|
483
|
+
Calls the parent class's __post_init__ method, validates the operator_groups_list,
|
484
|
+
and generates the name if not explicitly provided.
|
485
|
+
|
486
|
+
Raises:
|
487
|
+
Logger critical if operator_groups_list is not a list or if it contains fewer than two operators.
|
488
|
+
"""
|
489
|
+
super().__post_init__()
|
490
|
+
# Validate the operator_groups_list
|
491
|
+
if not isinstance(self.operator_groups_list, list):
|
492
|
+
Logger.critical(
|
493
|
+
f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") # pragma: no cover
|
494
|
+
if len(self.operator_groups_list) < 2:
|
495
|
+
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
|
479
496
|
|
480
|
-
|
481
|
-
|
497
|
+
# Generate the name from the operator groups if not provided
|
498
|
+
generated_name = '_'.join([x.name for x in self.operator_groups_list])
|
499
|
+
object.__setattr__(self, 'name', generated_name)
|
482
500
|
|
483
501
|
def contains(self, other: Any) -> bool:
|
484
502
|
"""
|
485
503
|
Determines if the current Fusing instance contains another Fusing instance.
|
486
504
|
|
487
505
|
Args:
|
488
|
-
other: The other Fusing instance to check against.
|
506
|
+
other (Any): The other Fusing instance to check against.
|
489
507
|
|
490
508
|
Returns:
|
491
|
-
|
509
|
+
bool: True if the other Fusing instance is contained within this one, False otherwise.
|
492
510
|
"""
|
493
511
|
if not isinstance(other, Fusing):
|
494
512
|
return False
|
@@ -506,208 +524,135 @@ class Fusing(TargetPlatformModelComponent):
|
|
506
524
|
# Other Fusing instance is not contained
|
507
525
|
return False
|
508
526
|
|
509
|
-
def get_info(self):
|
527
|
+
def get_info(self) -> Union[Dict[str, str], str]:
|
510
528
|
"""
|
511
529
|
Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
|
512
530
|
|
513
531
|
Returns:
|
514
|
-
A dictionary with the Fusing instance's name as the key
|
515
|
-
|
532
|
+
Union[Dict[str, str], str]: A dictionary with the Fusing instance's name as the key
|
533
|
+
and the sequence of operator groups as the value,
|
534
|
+
or just the sequence of operator groups if no name is set.
|
516
535
|
"""
|
517
536
|
if self.name is not None:
|
518
537
|
return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
|
519
538
|
return ' -> '.join([x.name for x in self.operator_groups_list])
|
520
539
|
|
521
540
|
|
522
|
-
|
541
|
+
@dataclass(frozen=True)
|
542
|
+
class TargetPlatformModel:
|
523
543
|
"""
|
524
544
|
Represents the hardware configuration used for quantized model inference.
|
525
545
|
|
526
|
-
This model defines:
|
527
|
-
- The operators and their associated quantization configurations.
|
528
|
-
- Fusing patterns, enabling multiple operators to be combined into a single operator
|
529
|
-
for optimization during inference.
|
530
|
-
- Versioning support through minor and patch versions for backward compatibility.
|
531
|
-
|
532
546
|
Attributes:
|
533
|
-
|
547
|
+
default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
|
548
|
+
tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
|
549
|
+
tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
|
550
|
+
tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
|
551
|
+
add_metadata (bool): Flag to determine if metadata should be added.
|
552
|
+
name (str): Name of the Target Platform Model.
|
553
|
+
operator_set (List[OperatorsSetBase]): List of operator sets within the model.
|
554
|
+
fusing_patterns (List[Fusing]): List of fusing patterns for the model.
|
555
|
+
is_simd_padding (bool): Indicates if SIMD padding is applied.
|
556
|
+
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
|
534
557
|
"""
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
Args:
|
546
|
-
default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model.
|
547
|
-
tpc_minor_version (Optional[int]): The minor version of the target platform capabilities.
|
548
|
-
tpc_patch_version (Optional[int]): The patch version of the target platform capabilities.
|
549
|
-
tpc_platform_type (Optional[str]): The platform type of the target platform capabilities.
|
550
|
-
add_metadata (bool): Whether to add metadata to the model or not.
|
551
|
-
name (str): Name of the model.
|
552
|
-
|
553
|
-
Raises:
|
554
|
-
AssertionError: If the provided `default_qco` does not contain exactly one quantization configuration.
|
555
|
-
"""
|
556
|
-
|
557
|
-
super().__init__()
|
558
|
-
self.tpc_minor_version = tpc_minor_version
|
559
|
-
self.tpc_patch_version = tpc_patch_version
|
560
|
-
self.tpc_platform_type = tpc_platform_type
|
561
|
-
self.add_metadata = add_metadata
|
562
|
-
self.name = name
|
563
|
-
self.operator_set = []
|
564
|
-
assert isinstance(default_qco, QuantizationConfigOptions), \
|
565
|
-
"default_qco must be an instance of QuantizationConfigOptions"
|
566
|
-
assert len(default_qco.quantization_config_list) == 1, \
|
567
|
-
"Default QuantizationConfigOptions must contain exactly one option."
|
568
|
-
|
569
|
-
self.default_qco = default_qco
|
570
|
-
self.fusing_patterns = []
|
571
|
-
self.is_simd_padding = False
|
572
|
-
|
573
|
-
def get_config_options_by_operators_set(self,
|
574
|
-
operators_set_name: str) -> QuantizationConfigOptions:
|
575
|
-
"""
|
576
|
-
Get the QuantizationConfigOptions of a OperatorsSet by the OperatorsSet name.
|
577
|
-
If the name is not in the model, the default QuantizationConfigOptions is returned.
|
578
|
-
|
579
|
-
Args:
|
580
|
-
operators_set_name: Name of OperatorsSet to get.
|
581
|
-
|
582
|
-
Returns:
|
583
|
-
QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name.
|
584
|
-
"""
|
585
|
-
for op_set in self.operator_set:
|
586
|
-
if operators_set_name == op_set.name:
|
587
|
-
return op_set.qc_options
|
588
|
-
return self.default_qco
|
589
|
-
|
590
|
-
def get_default_op_quantization_config(self) -> OpQuantizationConfig:
|
591
|
-
"""
|
592
|
-
|
593
|
-
Returns: The default OpQuantizationConfig of the TargetPlatformModel.
|
558
|
+
default_qco: QuantizationConfigOptions
|
559
|
+
tpc_minor_version: Optional[int]
|
560
|
+
tpc_patch_version: Optional[int]
|
561
|
+
tpc_platform_type: Optional[str]
|
562
|
+
add_metadata: bool = True
|
563
|
+
name: str = "default_tp_model"
|
564
|
+
operator_set: List[OperatorsSetBase] = field(default_factory=list)
|
565
|
+
fusing_patterns: List[Fusing] = field(default_factory=list)
|
566
|
+
is_simd_padding: bool = False
|
594
567
|
|
595
|
-
|
596
|
-
assert len(self.default_qco.quantization_config_list) == 1, \
|
597
|
-
f'Default quantization configuration options must contain only one option,' \
|
598
|
-
f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.'
|
599
|
-
return self.default_qco.quantization_config_list[0]
|
568
|
+
SCHEMA_VERSION: int = 1
|
600
569
|
|
601
|
-
def
|
602
|
-
opset_name: str) -> bool:
|
570
|
+
def __post_init__(self):
|
603
571
|
"""
|
604
|
-
|
605
|
-
|
606
|
-
Args:
|
607
|
-
opset_name: Operators set name to check.
|
572
|
+
Post-initialization processing for input validation.
|
608
573
|
|
609
|
-
|
610
|
-
|
574
|
+
Raises:
|
575
|
+
Logger critical if the default_qco is not an instance of QuantizationConfigOptions
|
576
|
+
or if it contains more than one quantization configuration.
|
611
577
|
"""
|
612
|
-
|
578
|
+
# Validate `default_qco`
|
579
|
+
if not isinstance(self.default_qco, QuantizationConfigOptions):
|
580
|
+
Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover
|
581
|
+
if len(self.default_qco.quantization_config_list) != 1:
|
582
|
+
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
|
613
583
|
|
614
|
-
def
|
615
|
-
opset_name: str) -> OperatorsSetBase:
|
584
|
+
def append_component(self, tp_model_component: TargetPlatformModelComponent):
|
616
585
|
"""
|
617
|
-
|
618
|
-
If name is not in the model - None is returned.
|
586
|
+
Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet).
|
619
587
|
|
620
588
|
Args:
|
621
|
-
|
622
|
-
|
623
|
-
Returns:
|
624
|
-
OperatorsSet object with the name opset_name, or None if opset_name is not in the model.
|
625
|
-
"""
|
626
|
-
|
627
|
-
opset_list = [x for x in self.operator_set if x.name == opset_name]
|
628
|
-
assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \
|
629
|
-
f' TargetPlatformModel with the name {opset_name}. ' \
|
630
|
-
f'OperatorsSet name must be unique.'
|
631
|
-
if len(opset_list) == 0: # opset_name is not in the model.
|
632
|
-
return None
|
633
|
-
|
634
|
-
return opset_list[0] # There's one opset with that name
|
635
|
-
|
636
|
-
def append_component(self,
|
637
|
-
tp_model_component: TargetPlatformModelComponent):
|
638
|
-
"""
|
639
|
-
Attach a TargetPlatformModel component to the model. Components can be for example:
|
640
|
-
Fusing, OperatorsSet, etc.
|
641
|
-
|
642
|
-
Args:
|
643
|
-
tp_model_component: Component to attach to the model.
|
589
|
+
tp_model_component (TargetPlatformModelComponent): Component to attach to the model.
|
644
590
|
|
591
|
+
Raises:
|
592
|
+
Logger critical if the component is not an instance of Fusing or OperatorsSetBase.
|
645
593
|
"""
|
646
594
|
if isinstance(tp_model_component, Fusing):
|
647
595
|
self.fusing_patterns.append(tp_model_component)
|
648
596
|
elif isinstance(tp_model_component, OperatorsSetBase):
|
649
597
|
self.operator_set.append(tp_model_component)
|
650
|
-
else:
|
651
|
-
Logger.critical(
|
652
|
-
|
653
|
-
def __enter__(self):
|
654
|
-
"""
|
655
|
-
Start defining the TargetPlatformModel using 'with'.
|
656
|
-
|
657
|
-
Returns: Initialized TargetPlatformModel object.
|
598
|
+
else:
|
599
|
+
Logger.critical(
|
600
|
+
f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") # pragma: no cover
|
658
601
|
|
602
|
+
def get_info(self) -> Dict[str, Any]:
|
659
603
|
"""
|
660
|
-
|
661
|
-
return self
|
604
|
+
Get a dictionary summarizing the TargetPlatformModel properties.
|
662
605
|
|
663
|
-
|
664
|
-
|
665
|
-
Finish defining the TargetPlatformModel at the end of the 'with' clause.
|
666
|
-
Returns the final and immutable TargetPlatformModel instance.
|
606
|
+
Returns:
|
607
|
+
Dict[str, Any]: Summary of the TargetPlatformModel properties.
|
667
608
|
"""
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
_current_tp_model.reset()
|
674
|
-
self.initialized_done() # Make model immutable.
|
675
|
-
return self
|
609
|
+
return {
|
610
|
+
"Model name": self.name,
|
611
|
+
"Operators sets": [o.get_info() for o in self.operator_set],
|
612
|
+
"Fusing patterns": [f.get_info() for f in self.fusing_patterns],
|
613
|
+
}
|
676
614
|
|
677
615
|
def __validate_model(self):
|
678
616
|
"""
|
617
|
+
Validate the model's configuration to ensure its integrity.
|
679
618
|
|
680
|
-
|
681
|
-
|
682
|
-
as their names should be unique.
|
683
|
-
|
619
|
+
Raises:
|
620
|
+
Logger critical if the model contains multiple operator sets with the same name.
|
684
621
|
"""
|
685
622
|
opsets_names = [op.name for op in self.operator_set]
|
686
623
|
if len(set(opsets_names)) != len(opsets_names):
|
687
|
-
Logger.critical(
|
624
|
+
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
|
688
625
|
|
689
|
-
def
|
626
|
+
def __enter__(self) -> 'TargetPlatformModel':
|
690
627
|
"""
|
628
|
+
Start defining the TargetPlatformModel using a 'with' statement.
|
691
629
|
|
692
630
|
Returns:
|
693
|
-
|
631
|
+
TargetPlatformModel: The initialized TargetPlatformModel object.
|
694
632
|
"""
|
695
|
-
|
696
|
-
|
697
|
-
f' but found {len(self.default_qco.quantization_config_list)} configurations.'
|
698
|
-
return self.default_qco.quantization_config_list[0]
|
633
|
+
_current_tp_model.set(self)
|
634
|
+
return self
|
699
635
|
|
700
|
-
def
|
636
|
+
def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel':
|
701
637
|
"""
|
638
|
+
Finalize and validate the TargetPlatformModel at the end of the 'with' clause.
|
639
|
+
|
640
|
+
Args:
|
641
|
+
exc_type: Exception type, if any occurred.
|
642
|
+
exc_value: Exception value, if any occurred.
|
643
|
+
tb: Traceback object, if an exception occurred.
|
702
644
|
|
703
|
-
|
645
|
+
Raises:
|
646
|
+
The exception raised in the 'with' block, if any.
|
704
647
|
|
648
|
+
Returns:
|
649
|
+
TargetPlatformModel: The validated TargetPlatformModel object.
|
705
650
|
"""
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
651
|
+
if exc_value is not None:
|
652
|
+
raise exc_value
|
653
|
+
self.__validate_model()
|
654
|
+
_current_tp_model.reset()
|
655
|
+
return self
|
711
656
|
|
712
657
|
def show(self):
|
713
658
|
"""
|
@@ -715,17 +660,4 @@ class TargetPlatformModel(ImmutableClass):
|
|
715
660
|
Display the TargetPlatformModel.
|
716
661
|
|
717
662
|
"""
|
718
|
-
pprint.pprint(self.get_info(), sort_dicts=False)
|
719
|
-
|
720
|
-
def set_simd_padding(self,
|
721
|
-
is_simd_padding: bool):
|
722
|
-
"""
|
723
|
-
Set flag is_simd_padding to indicate whether this TP model defines
|
724
|
-
that padding due to SIMD constrains occurs.
|
725
|
-
|
726
|
-
Args:
|
727
|
-
is_simd_padding: Whether this TP model defines that padding due to SIMD constrains occurs.
|
728
|
-
|
729
|
-
"""
|
730
|
-
self.is_simd_padding = is_simd_padding
|
731
|
-
|
663
|
+
pprint.pprint(self.get_info(), sort_dicts=False)
|