mct-nightly 2.2.0.20241211.531__py3-none-any.whl → 2.2.0.20241213.540__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21) hide show
  1. {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/RECORD +21 -21
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_node.py +3 -2
  5. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -2
  6. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +83 -14
  7. model_compression_toolkit/target_platform_capabilities/schema/v1.py +407 -475
  8. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -3
  9. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +5 -3
  10. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +5 -6
  11. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +3 -3
  12. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +3 -3
  13. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +5 -6
  14. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +3 -3
  15. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +5 -6
  16. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +3 -3
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +9 -9
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +2 -2
  19. {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/LICENSE.md +0 -0
  20. {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/WHEEL +0 -0
  21. {mct_nightly-2.2.0.20241211.531.dist-info → mct_nightly-2.2.0.20241213.540.dist-info}/top_level.txt +0 -0
@@ -12,23 +12,72 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import copy
16
-
17
- from enum import Enum
18
-
19
15
  import pprint
20
16
 
17
+ from dataclasses import replace, dataclass, asdict, field
18
+ from enum import Enum
21
19
  from typing import Dict, Any, Union, Tuple, List, Optional
22
-
23
20
  from mct_quantizers import QuantizationMethod
24
21
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
25
-
26
22
  from model_compression_toolkit.logger import Logger
27
23
  from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
28
- from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
29
24
  from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
30
- get_current_tp_model, _current_tp_model
31
- from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import clone_and_edit_object_params
25
+ _current_tp_model
26
+
27
+ class OperatorSetNames(Enum):
28
+ OPSET_CONV = "Conv"
29
+ OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
30
+ OPSET_CONV_TRANSPOSE = "ConvTraspose"
31
+ OPSET_FULLY_CONNECTED = "FullyConnected"
32
+ OPSET_CONCATENATE = "Concatenate"
33
+ OPSET_STACK = "Stack"
34
+ OPSET_UNSTACK = "Unstack"
35
+ OPSET_GATHER = "Gather"
36
+ OPSET_EXPAND = "Expend"
37
+ OPSET_BATCH_NORM = "BatchNorm"
38
+ OPSET_RELU = "ReLU"
39
+ OPSET_RELU6 = "ReLU6"
40
+ OPSET_LEAKY_RELU = "LEAKYReLU"
41
+ OPSET_HARD_TANH = "HardTanh"
42
+ OPSET_ADD = "Add"
43
+ OPSET_SUB = "Sub"
44
+ OPSET_MUL = "Mul"
45
+ OPSET_DIV = "Div"
46
+ OPSET_MIN_MAX = "MinMax"
47
+ OPSET_PRELU = "PReLU"
48
+ OPSET_SWISH = "Swish"
49
+ OPSET_SIGMOID = "Sigmoid"
50
+ OPSET_TANH = "Tanh"
51
+ OPSET_GELU = "Gelu"
52
+ OPSET_HARDSIGMOID = "HardSigmoid"
53
+ OPSET_HARDSWISH = "HardSwish"
54
+ OPSET_FLATTEN = "Flatten"
55
+ OPSET_GET_ITEM = "GetItem"
56
+ OPSET_RESHAPE = "Reshape"
57
+ OPSET_UNSQUEEZE = "Unsqueeze"
58
+ OPSET_SQUEEZE = "Squeeze"
59
+ OPSET_PERMUTE = "Permute"
60
+ OPSET_TRANSPOSE = "Transpose"
61
+ OPSET_DROPOUT = "Dropout"
62
+ OPSET_SPLIT = "Split"
63
+ OPSET_CHUNK = "Chunk"
64
+ OPSET_UNBIND = "Unbind"
65
+ OPSET_MAXPOOL = "MaxPool"
66
+ OPSET_SIZE = "Size"
67
+ OPSET_SHAPE = "Shape"
68
+ OPSET_EQUAL = "Equal"
69
+ OPSET_ARGMAX = "ArgMax"
70
+ OPSET_TOPK = "TopK"
71
+ OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS = "FakeQuantWithMinMaxVars"
72
+ OPSET_COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
73
+ OPSET_CROPPING2D = "Cropping2D"
74
+ OPSET_ZERO_PADDING2d = "ZeroPadding2D"
75
+ OPSET_CAST = "Cast"
76
+ OPSET_STRIDED_SLICE = "StridedSlice"
77
+
78
+ @classmethod
79
+ def get_values(cls):
80
+ return [v.value for v in cls]
32
81
 
33
82
 
34
83
  class Signedness(Enum):
@@ -44,451 +93,420 @@ class Signedness(Enum):
44
93
  UNSIGNED = 2
45
94
 
46
95
 
96
+ @dataclass(frozen=True)
47
97
  class AttributeQuantizationConfig:
48
98
  """
49
- Hold the quantization configuration of a weight attribute of a layer.
50
- """
51
- def __init__(self,
52
- weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO,
53
- weights_n_bits: int = FLOAT_BITWIDTH,
54
- weights_per_channel_threshold: bool = False,
55
- enable_weights_quantization: bool = False,
56
- lut_values_bitwidth: Union[int, None] = None, # If None - set 8 in hptq, o.w use it
57
- ):
58
- """
59
- Initializes an attribute quantization config.
60
-
61
- Args:
62
- weights_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for weights quantization.
63
- weights_n_bits (int): Number of bits to quantize the coefficients.
64
- weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
65
- enable_weights_quantization (bool): Whether to quantize the model weights or not.
66
- lut_values_bitwidth (int): Number of bits to use when quantizing in look-up-table.
99
+ Holds the quantization configuration of a weight attribute of a layer.
67
100
 
68
- """
69
-
70
- self.weights_quantization_method = weights_quantization_method
71
- self.weights_n_bits = weights_n_bits
72
- self.weights_per_channel_threshold = weights_per_channel_threshold
73
- self.enable_weights_quantization = enable_weights_quantization
74
- self.lut_values_bitwidth = lut_values_bitwidth
101
+ Attributes:
102
+ weights_quantization_method (QuantizationMethod): The method to use from QuantizationMethod for weights quantization.
103
+ weights_n_bits (int): Number of bits to quantize the coefficients.
104
+ weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor.
105
+ enable_weights_quantization (bool): Indicates whether to quantize the model weights or not.
106
+ lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table.
107
+ If None, defaults to 8 in hptq; otherwise, it uses the provided value.
108
+ """
109
+ weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO
110
+ weights_n_bits: int = FLOAT_BITWIDTH
111
+ weights_per_channel_threshold: bool = False
112
+ enable_weights_quantization: bool = False
113
+ lut_values_bitwidth: Optional[int] = None
75
114
 
76
- def clone_and_edit(self, **kwargs):
115
+ def __post_init__(self):
77
116
  """
78
- Clone the quantization config and edit some of its attributes.
117
+ Post-initialization processing for input validation.
79
118
 
80
- Args:
81
- **kwargs: Keyword arguments to edit the configuration to clone.
82
-
83
- Returns:
84
- Edited quantization configuration.
119
+ Raises:
120
+ Logger critical if attributes are of incorrect type or have invalid values.
85
121
  """
122
+ if not isinstance(self.weights_n_bits, int) or self.weights_n_bits < 1:
123
+ Logger.critical("weights_n_bits must be a positive integer.") # pragma: no cover
124
+ if not isinstance(self.enable_weights_quantization, bool):
125
+ Logger.critical("enable_weights_quantization must be a boolean.") # pragma: no cover
126
+ if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int):
127
+ Logger.critical("lut_values_bitwidth must be an integer or None.") # pragma: no cover
86
128
 
87
- return clone_and_edit_object_params(self, **kwargs)
88
-
89
- def __eq__(self, other):
129
+ def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig':
90
130
  """
91
- Is this configuration equal to another object.
131
+ Clone the current AttributeQuantizationConfig and edit some of its attributes.
92
132
 
93
133
  Args:
94
- other: Object to compare.
134
+ **kwargs: Keyword arguments representing the attributes to edit in the cloned instance.
95
135
 
96
136
  Returns:
97
-
98
- Whether this configuration is equal to another object or not.
137
+ AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes.
99
138
  """
100
- if not isinstance(other, AttributeQuantizationConfig):
101
- return False # pragma: no cover
102
- return self.weights_quantization_method == other.weights_quantization_method and \
103
- self.weights_n_bits == other.weights_n_bits and \
104
- self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
105
- self.enable_weights_quantization == other.enable_weights_quantization and \
106
- self.lut_values_bitwidth == other.lut_values_bitwidth
139
+ return replace(self, **kwargs)
107
140
 
108
141
 
142
+ @dataclass(frozen=True)
109
143
  class OpQuantizationConfig:
110
144
  """
111
145
  OpQuantizationConfig is a class to configure the quantization parameters of an operator.
112
- """
113
146
 
114
- def __init__(self,
115
- default_weight_attr_config: AttributeQuantizationConfig,
116
- attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig],
117
- activation_quantization_method: QuantizationMethod,
118
- activation_n_bits: int,
119
- supported_input_activation_n_bits: Union[int, Tuple[int]],
120
- enable_activation_quantization: bool,
121
- quantization_preserving: bool,
122
- fixed_scale: float,
123
- fixed_zero_point: int,
124
- simd_size: int,
125
- signedness: Signedness
126
- ):
127
- """
147
+ Args:
148
+ default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation.
149
+ attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
150
+ activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
151
+ activation_n_bits (int): Number of bits to quantize the activations.
152
+ supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input.
153
+ enable_activation_quantization (bool): Whether to quantize the model activations or not.
154
+ quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
155
+ fixed_scale (float): Scale to use for an operator quantization parameters.
156
+ fixed_zero_point (int): Zero-point to use for an operator quantization parameters.
157
+ simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
158
+ signedness (bool): Set activation quantization signedness.
128
159
 
129
- Args:
130
- default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation.
131
- attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
132
- activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
133
- activation_n_bits (int): Number of bits to quantize the activations.
134
- supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input.
135
- enable_activation_quantization (bool): Whether to quantize the model activations or not.
136
- quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
137
- fixed_scale (float): Scale to use for an operator quantization parameters.
138
- fixed_zero_point (int): Zero-point to use for an operator quantization parameters.
139
- simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
140
- signedness (bool): Set activation quantization signedness.
141
-
142
- """
143
-
144
- self.default_weight_attr_config = default_weight_attr_config
145
- self.attr_weights_configs_mapping = attr_weights_configs_mapping
146
-
147
- self.activation_quantization_method = activation_quantization_method
148
- self.activation_n_bits = activation_n_bits
149
- if isinstance(supported_input_activation_n_bits, tuple):
150
- self.supported_input_activation_n_bits = supported_input_activation_n_bits
151
- elif isinstance(supported_input_activation_n_bits, int):
152
- self.supported_input_activation_n_bits = (supported_input_activation_n_bits,)
153
- else:
154
- Logger.critical(f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(supported_input_activation_n_bits)}") # pragma: no cover
155
- self.enable_activation_quantization = enable_activation_quantization
156
- self.quantization_preserving = quantization_preserving
157
- self.fixed_scale = fixed_scale
158
- self.fixed_zero_point = fixed_zero_point
159
- self.signedness = signedness
160
- self.simd_size = simd_size
160
+ """
161
+ default_weight_attr_config: AttributeQuantizationConfig
162
+ attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig]
163
+ activation_quantization_method: QuantizationMethod
164
+ activation_n_bits: int
165
+ supported_input_activation_n_bits: Union[int, Tuple[int]]
166
+ enable_activation_quantization: bool
167
+ quantization_preserving: bool
168
+ fixed_scale: float
169
+ fixed_zero_point: int
170
+ simd_size: int
171
+ signedness: Signedness
172
+
173
+ def __post_init__(self):
174
+ """
175
+ Post-initialization processing for input validation.
176
+
177
+ Raises:
178
+ Logger critical if supported_input_activation_n_bits is not an int or a tuple of ints.
179
+ """
180
+ if isinstance(self.supported_input_activation_n_bits, int):
181
+ object.__setattr__(self, 'supported_input_activation_n_bits', (self.supported_input_activation_n_bits,))
182
+ elif not isinstance(self.supported_input_activation_n_bits, tuple):
183
+ Logger.critical(
184
+ f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(self.supported_input_activation_n_bits)}") # pragma: no cover
161
185
 
162
- def get_info(self):
186
+ def get_info(self) -> Dict[str, Any]:
163
187
  """
188
+ Get information about the quantization configuration.
164
189
 
165
- Returns: Info about the quantization configuration as a dictionary.
166
-
190
+ Returns:
191
+ dict: Information about the quantization configuration as a dictionary.
167
192
  """
168
- return self.__dict__ # pragma: no cover
193
+ return asdict(self) # pragma: no cover
169
194
 
170
- def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs):
195
+ def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs) -> 'OpQuantizationConfig':
171
196
  """
172
197
  Clone the quantization config and edit some of its attributes.
198
+
173
199
  Args:
174
- attr_to_edit: A mapping between attributes names to edit and their parameters that
175
- should be edited to a new value.
200
+ attr_to_edit (Dict[str, Dict[str, Any]]): A mapping between attribute names to edit and their parameters that
201
+ should be edited to a new value.
176
202
  **kwargs: Keyword arguments to edit the configuration to clone.
177
203
 
178
204
  Returns:
179
- Edited quantization configuration.
205
+ OpQuantizationConfig: Edited quantization configuration.
180
206
  """
181
207
 
182
- qc = clone_and_edit_object_params(self, **kwargs)
208
+ # Clone and update top-level attributes
209
+ updated_config = replace(self, **kwargs)
183
210
 
184
- # optionally: editing specific parameters in the config of specified attributes
185
- edited_attrs = copy.deepcopy(qc.attr_weights_configs_mapping)
186
- for attr_name, attr_cfg in qc.attr_weights_configs_mapping.items():
187
- if attr_name in attr_to_edit:
188
- edited_attrs[attr_name] = attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
211
+ # Clone and update nested immutable dataclasses in `attr_weights_configs_mapping`
212
+ updated_attr_mapping = {
213
+ attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
214
+ if attr_name in attr_to_edit else attr_cfg)
215
+ for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items()
216
+ }
189
217
 
190
- qc.attr_weights_configs_mapping = edited_attrs
218
+ # Return a new instance with the updated attribute mapping
219
+ return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping)
191
220
 
192
- return qc
193
221
 
194
- def __eq__(self, other):
195
- """
196
- Is this configuration equal to another object.
197
- Args:
198
- other: Object to compare.
222
+ @dataclass(frozen=True)
223
+ class QuantizationConfigOptions:
224
+ """
225
+ QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator.
199
226
 
200
- Returns:
201
- Whether this configuration is equal to another object or not.
202
- """
203
- if not isinstance(other, OpQuantizationConfig):
204
- return False # pragma: no cover
205
- return self.default_weight_attr_config == other.default_weight_attr_config and \
206
- self.attr_weights_configs_mapping == other.attr_weights_configs_mapping and \
207
- self.activation_quantization_method == other.activation_quantization_method and \
208
- self.activation_n_bits == other.activation_n_bits and \
209
- self.supported_input_activation_n_bits == other.supported_input_activation_n_bits and \
210
- self.enable_activation_quantization == other.enable_activation_quantization and \
211
- self.signedness == other.signedness and \
212
- self.simd_size == other.simd_size
227
+ Attributes:
228
+ quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather.
229
+ base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
230
+ """
231
+ quantization_config_list: List[OpQuantizationConfig]
232
+ base_config: Optional[OpQuantizationConfig] = None
213
233
 
214
- @property
215
- def max_input_activation_n_bits(self) -> int:
234
+ def __post_init__(self):
216
235
  """
217
- Get maximum supported input bit-width.
218
-
219
- Returns: Maximum supported input bit-width.
236
+ Post-initialization processing for input validation.
220
237
 
238
+ Raises:
239
+ Logger critical if quantization_config_list is not a list, contains invalid elements, or if base_config is not set correctly.
221
240
  """
222
- return max(self.supported_input_activation_n_bits)
241
+ # Validate `quantization_config_list`
242
+ if not isinstance(self.quantization_config_list, list):
243
+ Logger.critical(
244
+ f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") # pragma: no cover
245
+ for cfg in self.quantization_config_list:
246
+ if not isinstance(cfg, OpQuantizationConfig):
247
+ Logger.critical(
248
+ f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") # pragma: no cover
223
249
 
250
+ # Handle base_config
251
+ if len(self.quantization_config_list) > 1:
252
+ if self.base_config is None:
253
+ Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") # pragma: no cover
254
+ if not any(self.base_config == cfg for cfg in self.quantization_config_list):
255
+ Logger.critical(f"'base_config' must be included in the quantization config options list.") # pragma: no cover
256
+ elif len(self.quantization_config_list) == 1:
257
+ if self.base_config is None:
258
+ object.__setattr__(self, 'base_config', self.quantization_config_list[0])
259
+ elif self.base_config != self.quantization_config_list[0]:
260
+ Logger.critical(
261
+ "'base_config' should be the same as the sole item in 'quantization_config_list'.") # pragma: no cover
224
262
 
225
- class QuantizationConfigOptions:
226
- """
227
-
228
- Wrap a set of quantization configurations to consider during the quantization
229
- of an operator.
263
+ elif len(self.quantization_config_list) == 0:
264
+ Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") # pragma: no cover
230
265
 
231
- """
232
- def __init__(self,
233
- quantization_config_list: List[OpQuantizationConfig],
234
- base_config: OpQuantizationConfig = None):
266
+ def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
235
267
  """
268
+ Clone the quantization configuration options and edit attributes in each configuration.
236
269
 
237
270
  Args:
238
- quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather.
239
- base_config (OpQuantizationConfig): Fallback OpQuantizationConfig to use when optimizing the model in a non mixed-precision manner.
240
- """
241
-
242
- assert isinstance(quantization_config_list,
243
- list), f"'QuantizationConfigOptions' options list must be a list, but received: {type(quantization_config_list)}."
244
- for cfg in quantization_config_list:
245
- assert isinstance(cfg, OpQuantizationConfig),\
246
- f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}."
247
- self.quantization_config_list = quantization_config_list
248
- if len(quantization_config_list) > 1:
249
- assert base_config is not None, \
250
- f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization."
251
- assert any([base_config is cfg for cfg in quantization_config_list]), \
252
- f"'base_config' must be included in the quantization config options list."
253
- # Enforce base_config to be a reference to an instance in quantization_config_list.
254
- self.base_config = base_config
255
- elif len(quantization_config_list) == 1:
256
- assert base_config is None or base_config == quantization_config_list[0], "'base_config' should be included in 'quantization_config_list'"
257
- # Set base_config to be a reference to the first instance in quantization_config_list.
258
- self.base_config = quantization_config_list[0]
259
- else:
260
- raise AssertionError("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.")
261
-
262
- def __eq__(self, other):
263
- """
264
- Is this QCOptions equal to another object.
265
- Args:
266
- other: Object to compare.
271
+ **kwargs: Keyword arguments to edit in each configuration.
267
272
 
268
273
  Returns:
269
- Whether this QCOptions equal to another object or not.
274
+ A new instance of QuantizationConfigOptions with updated configurations.
270
275
  """
276
+ updated_base_config = replace(self.base_config, **kwargs)
277
+ updated_configs_list = [
278
+ replace(cfg, **kwargs) for cfg in self.quantization_config_list
279
+ ]
280
+ return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs_list)
271
281
 
272
- if not isinstance(other, QuantizationConfigOptions):
273
- return False
274
- if len(self.quantization_config_list) != len(other.quantization_config_list):
275
- return False
276
- for qc, other_qc in zip(self.quantization_config_list, other.quantization_config_list):
277
- if qc != other_qc:
278
- return False
279
- return True
280
-
281
- def clone_and_edit(self, **kwargs):
282
- qc_options = copy.deepcopy(self)
283
- for qc in qc_options.quantization_config_list:
284
- self.__edit_quantization_configuration(qc, kwargs)
285
- return qc_options
286
-
287
- def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs):
282
+ def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions':
288
283
  """
289
284
  Clones the quantization configurations and edits some of their attributes' parameters.
290
285
 
291
286
  Args:
292
- attrs: attributes names to clone their configurations. If None is provided, updating the configurations
293
- of all attributes in the operation attributes config mapping.
287
+ attrs (List[str]): Attributes names to clone and edit their configurations. If None, updates all attributes.
294
288
  **kwargs: Keyword arguments to edit in the attributes configuration.
295
289
 
296
290
  Returns:
297
- QuantizationConfigOptions with edited attributes configurations.
298
-
291
+ QuantizationConfigOptions: A new instance of QuantizationConfigOptions with edited attributes configurations.
299
292
  """
300
-
301
- qc_options = copy.deepcopy(self)
302
-
303
- for qc in qc_options.quantization_config_list:
293
+ updated_base_config = self.base_config
294
+ updated_configs = []
295
+ for qc in self.quantization_config_list:
304
296
  if attrs is None:
305
297
  attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
306
298
  else:
307
- if not isinstance(attrs, List): # pragma: no cover
308
- Logger.critical(f"Expected a list of attributes but received {type(attrs)}.")
309
299
  attrs_to_update = attrs
310
-
300
+ # Ensure all attributes exist in the config
311
301
  for attr in attrs_to_update:
312
- if qc.attr_weights_configs_mapping.get(attr) is None: # pragma: no cover
313
- Logger.critical(f'Editing attributes is only possible for existing attributes in the configuration\'s '
314
- f'weights config mapping; {attr} does not exist in {qc}.')
315
- self.__edit_quantization_configuration(qc.attr_weights_configs_mapping[attr], kwargs)
316
- return qc_options
302
+ if attr not in qc.attr_weights_configs_mapping:
303
+ Logger.critical(f"{attr} does not exist in {qc}.")
304
+ updated_attr_mapping = {
305
+ attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
306
+ for attr in attrs_to_update
307
+ }
308
+ if qc == updated_base_config:
309
+ updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping)
310
+ updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
311
+ return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs)
317
312
 
318
- def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]):
313
+ def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
319
314
  """
320
- Clones the quantization configuration options and edits the keys in each configuration attributes config mapping,
321
- based on the given attributes names mapping.
315
+ Clones the quantization configurations and updates keys in attribute config mappings.
322
316
 
323
317
  Args:
324
- layer_attrs_mapping: A mapping between attributes names.
318
+ layer_attrs_mapping (Optional[Dict[str, str]]): A mapping between attribute names.
325
319
 
326
320
  Returns:
327
- QuantizationConfigOptions with edited attributes names.
328
-
321
+ QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys.
329
322
  """
330
- qc_options = copy.deepcopy(self)
331
-
332
- # Extract the list of existing quantization configurations from qc_options
333
-
334
- # Check if the base_config is already included in the quantization configuration list
335
- # If not, add base_config to the list of configurations to update
336
- cfgs_to_update = [cfg for cfg in qc_options.quantization_config_list]
337
- if not any(qc_options.base_config is cfg for cfg in cfgs_to_update):
338
- # TODO: add test for this case
339
- cfgs_to_update.append(qc_options.base_config)
340
-
341
- for qc in cfgs_to_update:
323
+ updated_configs = []
324
+ new_base_config = self.base_config
325
+ for qc in self.quantization_config_list:
342
326
  if layer_attrs_mapping is None:
343
- qc.attr_weights_configs_mapping = {}
344
- else:
345
327
  new_attr_mapping = {}
346
- for attr in list(qc.attr_weights_configs_mapping.keys()):
347
- new_key = layer_attrs_mapping.get(attr)
348
- if new_key is None: # pragma: no cover
349
- Logger.critical(f"Attribute \'{attr}\' does not exist in the provided attribute mapping.")
350
-
351
- new_attr_mapping[new_key] = qc.attr_weights_configs_mapping.pop(attr)
352
-
353
- qc.attr_weights_configs_mapping.update(new_attr_mapping)
354
-
355
- return qc_options
328
+ else:
329
+ new_attr_mapping = {
330
+ layer_attrs_mapping.get(attr, attr): cfg
331
+ for attr, cfg in qc.attr_weights_configs_mapping.items()
332
+ }
333
+ if qc == self.base_config:
334
+ new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping)
335
+ updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping))
336
+ return replace(self, base_config=new_base_config, quantization_config_list=updated_configs)
356
337
 
357
- def __edit_quantization_configuration(self, qc, kwargs):
358
- for k, v in kwargs.items():
359
- assert hasattr(qc,
360
- k), (f'Editing is only possible for existing attributes in the configuration; '
361
- f'{k} is not an attribute of {qc}.')
362
- setattr(qc, k, v)
338
+ def get_info(self) -> Dict[str, Any]:
339
+ """
340
+ Get detailed information about each quantization configuration option.
363
341
 
364
- def get_info(self):
342
+ Returns:
343
+ dict: Information about the quantization configuration options as a dictionary.
344
+ """
365
345
  return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)}
366
346
 
367
347
 
348
+ @dataclass(frozen=True)
368
349
  class TargetPlatformModelComponent:
369
350
  """
370
- Component of TargetPlatformModel (Fusing, OperatorsSet, etc.)
351
+ Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
371
352
  """
372
- def __init__(self, name: str):
373
- """
374
353
 
375
- Args:
376
- name: Name of component.
354
+ def __post_init__(self):
355
+ """
356
+ Post-initialization to register the component with the current TargetPlatformModel.
377
357
  """
378
- self.name = name
379
358
  _current_tp_model.get().append_component(self)
380
359
 
381
360
  def get_info(self) -> Dict[str, Any]:
382
361
  """
362
+ Get information about the component to display.
383
363
 
384
- Returns: Get information about the component to display (return an empty dictionary.
385
- the actual component should fill it with info).
386
-
364
+ Returns:
365
+ Dict[str, Any]: Returns an empty dictionary. The actual component should override
366
+ this method to provide relevant information.
387
367
  """
388
368
  return {}
389
369
 
390
370
 
371
+ @dataclass(frozen=True)
391
372
  class OperatorsSetBase(TargetPlatformModelComponent):
392
373
  """
393
- Base class to represent a set of operators.
374
+ Base class to represent a set of a target platform model component of operator set types.
375
+ Inherits from TargetPlatformModelComponent.
394
376
  """
395
- def __init__(self, name: str):
377
+ def __post_init__(self):
396
378
  """
397
-
398
- Args:
399
- name: Name of OperatorsSet.
379
+ Post-initialization to ensure the component is registered with the TargetPlatformModel.
380
+ Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel.
400
381
  """
401
- super().__init__(name=name)
382
+ super().__post_init__()
402
383
 
403
384
 
385
+ @dataclass(frozen=True)
404
386
  class OperatorsSet(OperatorsSetBase):
405
- def __init__(self,
406
- name: str,
407
- qc_options: QuantizationConfigOptions = None):
408
- """
409
- Set of operators that are represented by a unique label.
387
+ """
388
+ Set of operators that are represented by a unique label.
410
389
 
411
- Args:
412
- name (str): Set's label (must be unique in a TargetPlatformModel).
413
- qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
414
- """
390
+ Attributes:
391
+ name (str): The set's label (must be unique within a TargetPlatformModel).
392
+ qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
393
+ If None, it represents a fusing set.
394
+ is_default (bool): Indicates whether this set is the default quantization configuration
395
+ for the TargetPlatformModel or a fusing set.
396
+ """
397
+ name: str
398
+ qc_options: QuantizationConfigOptions = None
415
399
 
416
- super().__init__(name)
417
- self.qc_options = qc_options
418
- is_fusing_set = qc_options is None
419
- self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
400
+ def __post_init__(self):
401
+ """
402
+ Post-initialization processing to mark the operator set as default if applicable.
420
403
 
404
+ Calls the parent class's __post_init__ method and sets `is_default` to True
405
+ if this set corresponds to the default quantization configuration for the
406
+ TargetPlatformModel or if it is a fusing set.
421
407
 
422
- def get_info(self) -> Dict[str,Any]:
423
408
  """
409
+ super().__post_init__()
410
+ is_fusing_set = self.qc_options is None
411
+ is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
412
+ object.__setattr__(self, 'is_default', is_default)
424
413
 
425
- Returns: Info about the set as a dictionary.
414
+ def get_info(self) -> Dict[str, Any]:
415
+ """
416
+ Get information about the set as a dictionary.
426
417
 
418
+ Returns:
419
+ Dict[str, Any]: A dictionary containing the set name and
420
+ whether it is the default quantization configuration.
427
421
  """
428
422
  return {"name": self.name,
429
423
  "is_default_qc": self.is_default}
430
424
 
431
425
 
426
+ @dataclass(frozen=True)
432
427
  class OperatorSetConcat(OperatorsSetBase):
433
428
  """
434
429
  Concatenate a list of operator sets to treat them similarly in different places (like fusing).
430
+
431
+ Attributes:
432
+ op_set_list (List[OperatorsSet]): List of operator sets to group.
433
+ qc_options (None): Configuration options for the set, always None for concatenated sets.
434
+ name (str): Concatenated name generated from the names of the operator sets in the list.
435
435
  """
436
- def __init__(self, *opsets: OperatorsSet):
437
- """
438
- Group a list of operation sets.
436
+ op_set_list: List[OperatorsSet] = field(default_factory=list)
437
+ qc_options: None = field(default=None, init=False)
438
+ name: str = None
439
439
 
440
- Args:
441
- *opsets (OperatorsSet): List of operator sets to group.
440
+ def __post_init__(self):
442
441
  """
443
- name = "_".join([a.name for a in opsets])
444
- super().__init__(name=name)
445
- self.op_set_list = opsets
446
- self.qc_options = None # Concat have no qc options
442
+ Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
447
443
 
448
- def get_info(self) -> Dict[str,Any]:
444
+ Calls the parent class's __post_init__ method and creates a concatenated name
445
+ by joining the names of all operator sets in `op_set_list`.
449
446
  """
447
+ super().__post_init__()
448
+ # Generate the concatenated name from the operator sets
449
+ concatenated_name = "_".join([op.name for op in self.op_set_list])
450
+ # Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
451
+ object.__setattr__(self, "name", concatenated_name)
450
452
 
451
- Returns: Info about the sets group as a dictionary.
453
+ def get_info(self) -> Dict[str, Any]:
454
+ """
455
+ Get information about the concatenated set as a dictionary.
452
456
 
457
+ Returns:
458
+ Dict[str, Any]: A dictionary containing the concatenated name and
459
+ the list of names of the operator sets in `op_set_list`.
453
460
  """
454
461
  return {"name": self.name,
455
462
  OPS_SET_LIST: [s.name for s in self.op_set_list]}
456
463
 
457
464
 
465
+ @dataclass(frozen=True)
458
466
  class Fusing(TargetPlatformModelComponent):
459
467
  """
460
- Fusing defines a list of operators that should be combined and treated as a single operator,
461
- hence no quantization is applied between them.
468
+ Fusing defines a list of operators that should be combined and treated as a single operator,
469
+ hence no quantization is applied between them.
470
+
471
+ Attributes:
472
+ operator_groups_list (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups,
473
+ each being either an OperatorSetConcat or an OperatorsSet.
474
+ name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
462
475
  """
476
+ operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]]
477
+ name: str = None
463
478
 
464
- def __init__(self,
465
- operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
466
- name: str = None):
467
- """
468
- Args:
469
- operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
470
- name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
479
+ def __post_init__(self):
471
480
  """
472
- assert isinstance(operator_groups_list,
473
- list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
474
- assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
481
+ Post-initialization processing for input validation and name generation.
475
482
 
476
- # Generate a name from the operator groups if no name is provided
477
- if name is None:
478
- name = '_'.join([x.name for x in operator_groups_list])
483
+ Calls the parent class's __post_init__ method, validates the operator_groups_list,
484
+ and generates the name if not explicitly provided.
485
+
486
+ Raises:
487
+ Logger critical if operator_groups_list is not a list or if it contains fewer than two operators.
488
+ """
489
+ super().__post_init__()
490
+ # Validate the operator_groups_list
491
+ if not isinstance(self.operator_groups_list, list):
492
+ Logger.critical(
493
+ f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") # pragma: no cover
494
+ if len(self.operator_groups_list) < 2:
495
+ Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
479
496
 
480
- super().__init__(name)
481
- self.operator_groups_list = operator_groups_list
497
+ # Generate the name from the operator groups if not provided
498
+ generated_name = '_'.join([x.name for x in self.operator_groups_list])
499
+ object.__setattr__(self, 'name', generated_name)
482
500
 
483
501
  def contains(self, other: Any) -> bool:
484
502
  """
485
503
  Determines if the current Fusing instance contains another Fusing instance.
486
504
 
487
505
  Args:
488
- other: The other Fusing instance to check against.
506
+ other (Any): The other Fusing instance to check against.
489
507
 
490
508
  Returns:
491
- A boolean indicating whether the other instance is contained within this one.
509
+ bool: True if the other Fusing instance is contained within this one, False otherwise.
492
510
  """
493
511
  if not isinstance(other, Fusing):
494
512
  return False
@@ -506,208 +524,135 @@ class Fusing(TargetPlatformModelComponent):
506
524
  # Other Fusing instance is not contained
507
525
  return False
508
526
 
509
- def get_info(self):
527
+ def get_info(self) -> Union[Dict[str, str], str]:
510
528
  """
511
529
  Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
512
530
 
513
531
  Returns:
514
- A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value,
515
- or just the sequence of operator groups if no name is set.
532
+ Union[Dict[str, str], str]: A dictionary with the Fusing instance's name as the key
533
+ and the sequence of operator groups as the value,
534
+ or just the sequence of operator groups if no name is set.
516
535
  """
517
536
  if self.name is not None:
518
537
  return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
519
538
  return ' -> '.join([x.name for x in self.operator_groups_list])
520
539
 
521
540
 
522
- class TargetPlatformModel(ImmutableClass):
541
+ @dataclass(frozen=True)
542
+ class TargetPlatformModel:
523
543
  """
524
544
  Represents the hardware configuration used for quantized model inference.
525
545
 
526
- This model defines:
527
- - The operators and their associated quantization configurations.
528
- - Fusing patterns, enabling multiple operators to be combined into a single operator
529
- for optimization during inference.
530
- - Versioning support through minor and patch versions for backward compatibility.
531
-
532
546
  Attributes:
533
- SCHEMA_VERSION (int): The schema version of the target platform model.
547
+ default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
548
+ tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
549
+ tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
550
+ tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
551
+ add_metadata (bool): Flag to determine if metadata should be added.
552
+ name (str): Name of the Target Platform Model.
553
+ operator_set (List[OperatorsSetBase]): List of operator sets within the model.
554
+ fusing_patterns (List[Fusing]): List of fusing patterns for the model.
555
+ is_simd_padding (bool): Indicates if SIMD padding is applied.
556
+ SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
534
557
  """
535
- SCHEMA_VERSION = 1
536
- def __init__(self,
537
- default_qco: QuantizationConfigOptions,
538
- tpc_minor_version: Optional[int],
539
- tpc_patch_version: Optional[int],
540
- tpc_platform_type: Optional[str],
541
- add_metadata: bool = True,
542
- name="default_tp_model"):
543
- """
544
-
545
- Args:
546
- default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model.
547
- tpc_minor_version (Optional[int]): The minor version of the target platform capabilities.
548
- tpc_patch_version (Optional[int]): The patch version of the target platform capabilities.
549
- tpc_platform_type (Optional[str]): The platform type of the target platform capabilities.
550
- add_metadata (bool): Whether to add metadata to the model or not.
551
- name (str): Name of the model.
552
-
553
- Raises:
554
- AssertionError: If the provided `default_qco` does not contain exactly one quantization configuration.
555
- """
556
-
557
- super().__init__()
558
- self.tpc_minor_version = tpc_minor_version
559
- self.tpc_patch_version = tpc_patch_version
560
- self.tpc_platform_type = tpc_platform_type
561
- self.add_metadata = add_metadata
562
- self.name = name
563
- self.operator_set = []
564
- assert isinstance(default_qco, QuantizationConfigOptions), \
565
- "default_qco must be an instance of QuantizationConfigOptions"
566
- assert len(default_qco.quantization_config_list) == 1, \
567
- "Default QuantizationConfigOptions must contain exactly one option."
568
-
569
- self.default_qco = default_qco
570
- self.fusing_patterns = []
571
- self.is_simd_padding = False
572
-
573
- def get_config_options_by_operators_set(self,
574
- operators_set_name: str) -> QuantizationConfigOptions:
575
- """
576
- Get the QuantizationConfigOptions of a OperatorsSet by the OperatorsSet name.
577
- If the name is not in the model, the default QuantizationConfigOptions is returned.
578
-
579
- Args:
580
- operators_set_name: Name of OperatorsSet to get.
581
-
582
- Returns:
583
- QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name.
584
- """
585
- for op_set in self.operator_set:
586
- if operators_set_name == op_set.name:
587
- return op_set.qc_options
588
- return self.default_qco
589
-
590
- def get_default_op_quantization_config(self) -> OpQuantizationConfig:
591
- """
592
-
593
- Returns: The default OpQuantizationConfig of the TargetPlatformModel.
558
+ default_qco: QuantizationConfigOptions
559
+ tpc_minor_version: Optional[int]
560
+ tpc_patch_version: Optional[int]
561
+ tpc_platform_type: Optional[str]
562
+ add_metadata: bool = True
563
+ name: str = "default_tp_model"
564
+ operator_set: List[OperatorsSetBase] = field(default_factory=list)
565
+ fusing_patterns: List[Fusing] = field(default_factory=list)
566
+ is_simd_padding: bool = False
594
567
 
595
- """
596
- assert len(self.default_qco.quantization_config_list) == 1, \
597
- f'Default quantization configuration options must contain only one option,' \
598
- f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.'
599
- return self.default_qco.quantization_config_list[0]
568
+ SCHEMA_VERSION: int = 1
600
569
 
601
- def is_opset_in_model(self,
602
- opset_name: str) -> bool:
570
+ def __post_init__(self):
603
571
  """
604
- Check whether an operators set is defined in the model or not.
605
-
606
- Args:
607
- opset_name: Operators set name to check.
572
+ Post-initialization processing for input validation.
608
573
 
609
- Returns:
610
- Whether an operators set is defined in the model or not.
574
+ Raises:
575
+ Logger critical if the default_qco is not an instance of QuantizationConfigOptions
576
+ or if it contains more than one quantization configuration.
611
577
  """
612
- return opset_name in [x.name for x in self.operator_set]
578
+ # Validate `default_qco`
579
+ if not isinstance(self.default_qco, QuantizationConfigOptions):
580
+ Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover
581
+ if len(self.default_qco.quantization_config_list) != 1:
582
+ Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
613
583
 
614
- def get_opset_by_name(self,
615
- opset_name: str) -> OperatorsSetBase:
584
+ def append_component(self, tp_model_component: TargetPlatformModelComponent):
616
585
  """
617
- Get an OperatorsSet object from the model by its name.
618
- If name is not in the model - None is returned.
586
+ Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet).
619
587
 
620
588
  Args:
621
- opset_name: OperatorsSet name to retrieve.
622
-
623
- Returns:
624
- OperatorsSet object with the name opset_name, or None if opset_name is not in the model.
625
- """
626
-
627
- opset_list = [x for x in self.operator_set if x.name == opset_name]
628
- assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \
629
- f' TargetPlatformModel with the name {opset_name}. ' \
630
- f'OperatorsSet name must be unique.'
631
- if len(opset_list) == 0: # opset_name is not in the model.
632
- return None
633
-
634
- return opset_list[0] # There's one opset with that name
635
-
636
- def append_component(self,
637
- tp_model_component: TargetPlatformModelComponent):
638
- """
639
- Attach a TargetPlatformModel component to the model. Components can be for example:
640
- Fusing, OperatorsSet, etc.
641
-
642
- Args:
643
- tp_model_component: Component to attach to the model.
589
+ tp_model_component (TargetPlatformModelComponent): Component to attach to the model.
644
590
 
591
+ Raises:
592
+ Logger critical if the component is not an instance of Fusing or OperatorsSetBase.
645
593
  """
646
594
  if isinstance(tp_model_component, Fusing):
647
595
  self.fusing_patterns.append(tp_model_component)
648
596
  elif isinstance(tp_model_component, OperatorsSetBase):
649
597
  self.operator_set.append(tp_model_component)
650
- else: # pragma: no cover
651
- Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.')
652
-
653
- def __enter__(self):
654
- """
655
- Start defining the TargetPlatformModel using 'with'.
656
-
657
- Returns: Initialized TargetPlatformModel object.
598
+ else:
599
+ Logger.critical(
600
+ f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") # pragma: no cover
658
601
 
602
+ def get_info(self) -> Dict[str, Any]:
659
603
  """
660
- _current_tp_model.set(self)
661
- return self
604
+ Get a dictionary summarizing the TargetPlatformModel properties.
662
605
 
663
- def __exit__(self, exc_type, exc_value, tb):
664
- """
665
- Finish defining the TargetPlatformModel at the end of the 'with' clause.
666
- Returns the final and immutable TargetPlatformModel instance.
606
+ Returns:
607
+ Dict[str, Any]: Summary of the TargetPlatformModel properties.
667
608
  """
668
-
669
- if exc_value is not None:
670
- print(exc_value, exc_value.args)
671
- raise exc_value
672
- self.__validate_model() # Assert that model is valid.
673
- _current_tp_model.reset()
674
- self.initialized_done() # Make model immutable.
675
- return self
609
+ return {
610
+ "Model name": self.name,
611
+ "Operators sets": [o.get_info() for o in self.operator_set],
612
+ "Fusing patterns": [f.get_info() for f in self.fusing_patterns],
613
+ }
676
614
 
677
615
  def __validate_model(self):
678
616
  """
617
+ Validate the model's configuration to ensure its integrity.
679
618
 
680
- Assert model is valid.
681
- Model is invalid if, for example, it contains multiple operator sets with the same name,
682
- as their names should be unique.
683
-
619
+ Raises:
620
+ Logger critical if the model contains multiple operator sets with the same name.
684
621
  """
685
622
  opsets_names = [op.name for op in self.operator_set]
686
623
  if len(set(opsets_names)) != len(opsets_names):
687
- Logger.critical(f'Operator Sets must have unique names.')
624
+ Logger.critical("Operator Sets must have unique names.") # pragma: no cover
688
625
 
689
- def get_default_config(self) -> OpQuantizationConfig:
626
+ def __enter__(self) -> 'TargetPlatformModel':
690
627
  """
628
+ Start defining the TargetPlatformModel using a 'with' statement.
691
629
 
692
630
  Returns:
693
-
631
+ TargetPlatformModel: The initialized TargetPlatformModel object.
694
632
  """
695
- assert len(self.default_qco.quantization_config_list) == 1, \
696
- f'Default quantization configuration options must contain only one option,' \
697
- f' but found {len(self.default_qco.quantization_config_list)} configurations.'
698
- return self.default_qco.quantization_config_list[0]
633
+ _current_tp_model.set(self)
634
+ return self
699
635
 
700
- def get_info(self) -> Dict[str, Any]:
636
+ def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel':
701
637
  """
638
+ Finalize and validate the TargetPlatformModel at the end of the 'with' clause.
639
+
640
+ Args:
641
+ exc_type: Exception type, if any occurred.
642
+ exc_value: Exception value, if any occurred.
643
+ tb: Traceback object, if an exception occurred.
702
644
 
703
- Returns: Dictionary that summarizes the TargetPlatformModel properties (for display purposes).
645
+ Raises:
646
+ The exception raised in the 'with' block, if any.
704
647
 
648
+ Returns:
649
+ TargetPlatformModel: The validated TargetPlatformModel object.
705
650
  """
706
- return {"Model name": self.name,
707
- "Default quantization config": self.get_default_config().get_info(),
708
- "Operators sets": [o.get_info() for o in self.operator_set],
709
- "Fusing patterns": [f.get_info() for f in self.fusing_patterns]
710
- }
651
+ if exc_value is not None:
652
+ raise exc_value
653
+ self.__validate_model()
654
+ _current_tp_model.reset()
655
+ return self
711
656
 
712
657
  def show(self):
713
658
  """
@@ -715,17 +660,4 @@ class TargetPlatformModel(ImmutableClass):
715
660
  Display the TargetPlatformModel.
716
661
 
717
662
  """
718
- pprint.pprint(self.get_info(), sort_dicts=False)
719
-
720
- def set_simd_padding(self,
721
- is_simd_padding: bool):
722
- """
723
- Set flag is_simd_padding to indicate whether this TP model defines
724
- that padding due to SIMD constrains occurs.
725
-
726
- Args:
727
- is_simd_padding: Whether this TP model defines that padding due to SIMD constrains occurs.
728
-
729
- """
730
- self.is_simd_padding = is_simd_padding
731
-
663
+ pprint.pprint(self.get_info(), sort_dicts=False)