mct-nightly 2.2.0.20241230.534__py3-none-any.whl → 2.2.0.20250102.111338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/METADATA +8 -11
  2. {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/RECORD +19 -19
  3. {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/target_platform_capabilities/schema/v1.py +308 -173
  6. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +22 -22
  7. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +22 -22
  8. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +22 -22
  9. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +21 -21
  10. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +22 -22
  11. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +25 -25
  12. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +23 -23
  13. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +55 -40
  14. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +4 -6
  15. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +2 -4
  16. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +10 -10
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +49 -46
  18. {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/LICENSE.md +0 -0
  19. {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/top_level.txt +0 -0
@@ -14,13 +14,13 @@
14
14
  # ==============================================================================
15
15
  import pprint
16
16
 
17
- from dataclasses import replace, dataclass, asdict, field
18
17
  from enum import Enum
19
- from typing import Dict, Any, Union, Tuple, List, Optional
18
+ from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
20
19
  from mct_quantizers import QuantizationMethod
21
20
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
22
21
  from model_compression_toolkit.logger import Logger
23
- from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
22
+ from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, PrivateAttr
23
+
24
24
 
25
25
  class OperatorSetNames(Enum):
26
26
  OPSET_CONV = "Conv"
@@ -92,8 +92,7 @@ class Signedness(Enum):
92
92
  UNSIGNED = 2
93
93
 
94
94
 
95
- @dataclass(frozen=True)
96
- class AttributeQuantizationConfig:
95
+ class AttributeQuantizationConfig(BaseModel):
97
96
  """
98
97
  Holds the quantization configuration of a weight attribute of a layer.
99
98
 
@@ -103,27 +102,22 @@ class AttributeQuantizationConfig:
103
102
  weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor.
104
103
  enable_weights_quantization (bool): Indicates whether to quantize the model weights or not.
105
104
  lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table.
106
- If None, defaults to 8 in hptq; otherwise, it uses the provided value.
105
+ If None, defaults to 8 in hptq; otherwise, it uses the provided value.
107
106
  """
108
107
  weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO
109
- weights_n_bits: int = FLOAT_BITWIDTH
108
+ weights_n_bits: PositiveInt = FLOAT_BITWIDTH
110
109
  weights_per_channel_threshold: bool = False
111
110
  enable_weights_quantization: bool = False
112
111
  lut_values_bitwidth: Optional[int] = None
113
112
 
114
- def __post_init__(self):
115
- """
116
- Post-initialization processing for input validation.
113
+ class Config:
114
+ # Makes the model immutable (frozen)
115
+ frozen = True
117
116
 
118
- Raises:
119
- Logger critical if attributes are of incorrect type or have invalid values.
120
- """
121
- if not isinstance(self.weights_n_bits, int) or self.weights_n_bits < 1:
122
- Logger.critical("weights_n_bits must be a positive integer.") # pragma: no cover
123
- if not isinstance(self.enable_weights_quantization, bool):
124
- Logger.critical("enable_weights_quantization must be a boolean.") # pragma: no cover
125
- if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int):
126
- Logger.critical("lut_values_bitwidth must be an integer or None.") # pragma: no cover
117
+ @property
118
+ def field_names(self) -> list:
119
+ """Return a list of field names for the model."""
120
+ return list(self.__fields__.keys())
127
121
 
128
122
  def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig':
129
123
  """
@@ -135,11 +129,10 @@ class AttributeQuantizationConfig:
135
129
  Returns:
136
130
  AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes.
137
131
  """
138
- return replace(self, **kwargs)
132
+ return self.copy(update=kwargs)
139
133
 
140
134
 
141
- @dataclass(frozen=True)
142
- class OpQuantizationConfig:
135
+ class OpQuantizationConfig(BaseModel):
143
136
  """
144
137
  OpQuantizationConfig is a class to configure the quantization parameters of an operator.
145
138
 
@@ -148,39 +141,45 @@ class OpQuantizationConfig:
148
141
  attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
149
142
  activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
150
143
  activation_n_bits (int): Number of bits to quantize the activations.
151
- supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input.
144
+ supported_input_activation_n_bits (Union[int, Tuple[int, ...]]): Number of bits that operator accepts as input.
152
145
  enable_activation_quantization (bool): Whether to quantize the model activations or not.
153
146
  quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
154
- fixed_scale (float): Scale to use for an operator quantization parameters.
155
- fixed_zero_point (int): Zero-point to use for an operator quantization parameters.
156
- simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
157
- signedness (bool): Set activation quantization signedness.
158
-
147
+ fixed_scale (Optional[float]): Scale to use for an operator quantization parameters.
148
+ fixed_zero_point (Optional[int]): Zero-point to use for an operator quantization parameters.
149
+ simd_size (Optional[int]): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
150
+ signedness (Signedness): Set activation quantization signedness.
159
151
  """
160
152
  default_weight_attr_config: AttributeQuantizationConfig
161
153
  attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig]
162
154
  activation_quantization_method: QuantizationMethod
163
155
  activation_n_bits: int
164
- supported_input_activation_n_bits: Union[int, Tuple[int]]
156
+ supported_input_activation_n_bits: Union[int, Tuple[int, ...]]
165
157
  enable_activation_quantization: bool
166
158
  quantization_preserving: bool
167
- fixed_scale: float
168
- fixed_zero_point: int
169
- simd_size: int
159
+ fixed_scale: Optional[float]
160
+ fixed_zero_point: Optional[int]
161
+ simd_size: Optional[int]
170
162
  signedness: Signedness
171
163
 
172
- def __post_init__(self):
173
- """
174
- Post-initialization processing for input validation.
164
+ class Config:
165
+ frozen = True
175
166
 
176
- Raises:
177
- Logger critical if supported_input_activation_n_bits is not an int or a tuple of ints.
167
+ @validator('supported_input_activation_n_bits', pre=True, allow_reuse=True)
168
+ def validate_supported_input_activation_n_bits(cls, v):
178
169
  """
179
- if isinstance(self.supported_input_activation_n_bits, int):
180
- object.__setattr__(self, 'supported_input_activation_n_bits', (self.supported_input_activation_n_bits,))
181
- elif not isinstance(self.supported_input_activation_n_bits, tuple):
182
- Logger.critical(
183
- f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(self.supported_input_activation_n_bits)}") # pragma: no cover
170
+ Validate and process the supported_input_activation_n_bits field.
171
+ Converts an int to a tuple containing that int.
172
+ Ensures that if a tuple is provided, all elements are ints.
173
+ """
174
+
175
+ if isinstance(v, int):
176
+ v = (v,)
177
+
178
+ # When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
179
+ if isinstance(v, list):
180
+ v = tuple(v)
181
+
182
+ return v
184
183
 
185
184
  def get_info(self) -> Dict[str, Any]:
186
185
  """
@@ -189,9 +188,13 @@ class OpQuantizationConfig:
189
188
  Returns:
190
189
  dict: Information about the quantization configuration as a dictionary.
191
190
  """
192
- return asdict(self) # pragma: no cover
191
+ return self.dict() # pragma: no cover
193
192
 
194
- def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs) -> 'OpQuantizationConfig':
193
+ def clone_and_edit(
194
+ self,
195
+ attr_to_edit: Dict[str, Dict[str, Any]] = {},
196
+ **kwargs: Any
197
+ ) -> 'OpQuantizationConfig':
195
198
  """
196
199
  Clone the quantization config and edit some of its attributes.
197
200
 
@@ -203,64 +206,87 @@ class OpQuantizationConfig:
203
206
  Returns:
204
207
  OpQuantizationConfig: Edited quantization configuration.
205
208
  """
206
-
207
209
  # Clone and update top-level attributes
208
- updated_config = replace(self, **kwargs)
210
+ updated_config = self.copy(update=kwargs)
209
211
 
210
212
  # Clone and update nested immutable dataclasses in `attr_weights_configs_mapping`
211
213
  updated_attr_mapping = {
212
214
  attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
213
- if attr_name in attr_to_edit else attr_cfg)
215
+ if attr_name in attr_to_edit else attr_cfg)
214
216
  for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items()
215
217
  }
216
218
 
217
219
  # Return a new instance with the updated attribute mapping
218
- return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping)
220
+ return updated_config.copy(update={'attr_weights_configs_mapping': updated_attr_mapping})
219
221
 
220
222
 
221
- @dataclass(frozen=True)
222
- class QuantizationConfigOptions:
223
+ class QuantizationConfigOptions(BaseModel):
223
224
  """
224
225
  QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator.
225
226
 
226
227
  Attributes:
227
- quantization_configurations (Tuple[OpQuantizationConfig]): Tuple of possible OpQuantizationConfig to gather.
228
+ quantization_configurations (Tuple[OpQuantizationConfig, ...]): Tuple of possible OpQuantizationConfig to gather.
228
229
  base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
229
230
  """
230
- quantization_configurations: Tuple[OpQuantizationConfig]
231
+ quantization_configurations: Tuple[OpQuantizationConfig, ...]
231
232
  base_config: Optional[OpQuantizationConfig] = None
232
233
 
233
- def __post_init__(self):
234
+ class Config:
235
+ frozen = True
236
+
237
+ @root_validator(pre=True, allow_reuse=True)
238
+ def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
234
239
  """
235
- Post-initialization processing for input validation.
240
+ Validate and set the base_config based on quantization_configurations.
241
+
242
+ Args:
243
+ values (Dict[str, Any]): Input data.
236
244
 
237
- Raises:
238
- Logger critical if quantization_configurations is not a tuple, contains invalid elements, or if base_config is not set correctly.
245
+ Returns:
246
+ Dict[str, Any]: Modified input data with base_config set appropriately.
239
247
  """
240
- # Validate `quantization_configurations`
241
- if not isinstance(self.quantization_configurations, tuple):
248
+ quantization_configurations = values.get('quantization_configurations', ())
249
+ num_configs = len(quantization_configurations)
250
+ base_config = values.get('base_config')
251
+
252
+ if not isinstance(quantization_configurations, (tuple, list)):
242
253
  Logger.critical(
243
- f"'quantization_configurations' must be a tuple, but received: {type(self.quantization_configurations)}.") # pragma: no cover
244
- for cfg in self.quantization_configurations:
245
- if not isinstance(cfg, OpQuantizationConfig):
246
- Logger.critical(
247
- f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") # pragma: no cover
248
-
249
- # Handle base_config
250
- if len(self.quantization_configurations) > 1:
251
- if self.base_config is None:
252
- Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") # pragma: no cover
253
- if not any(self.base_config == cfg for cfg in self.quantization_configurations):
254
- Logger.critical(f"'base_config' must be included in the quantization config options.") # pragma: no cover
255
- elif len(self.quantization_configurations) == 1:
256
- if self.base_config is None:
257
- object.__setattr__(self, 'base_config', self.quantization_configurations[0])
258
- elif self.base_config != self.quantization_configurations[0]:
254
+ f"'quantization_configurations' must be a list or tuple, but received: {type(quantization_configurations)}."
255
+ ) # pragma: no cover
256
+
257
+ if num_configs == 0:
258
+ Logger.critical(
259
+ "'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided configurations are empty."
260
+ ) # pragma: no cover
261
+
262
+ if base_config is None:
263
+ if num_configs > 1:
259
264
  Logger.critical(
260
- "'base_config' should be the same as the sole item in 'quantization_configurations'.") # pragma: no cover
265
+ "For multiple configurations, a 'base_config' is required for non-mixed-precision optimization."
266
+ ) # pragma: no cover
267
+ else:
268
+ # Automatically set base_config to the sole configuration
269
+ base_config = quantization_configurations[0]
261
270
 
262
- elif len(self.quantization_configurations) == 0:
263
- Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided configurations is empty.") # pragma: no cover
271
+
272
+ if base_config not in quantization_configurations:
273
+ Logger.critical(
274
+ "'base_config' must be included in the quantization config options."
275
+ ) # pragma: no cover
276
+
277
+ # if num_configs == 1:
278
+ # if base_config != quantization_configurations[0]:
279
+ # Logger.critical(
280
+ # "'base_config' should be the same as the sole item in 'quantization_configurations'."
281
+ # ) # pragma: no cover
282
+
283
+ values['base_config'] = base_config
284
+
285
+ # When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
286
+ if isinstance(quantization_configurations, list):
287
+ values['quantization_configurations'] = tuple(quantization_configurations)
288
+
289
+ return values
264
290
 
265
291
  def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
266
292
  """
@@ -270,46 +296,71 @@ class QuantizationConfigOptions:
270
296
  **kwargs: Keyword arguments to edit in each configuration.
271
297
 
272
298
  Returns:
273
- A new instance of QuantizationConfigOptions with updated configurations.
299
+ QuantizationConfigOptions: A new instance with updated configurations.
274
300
  """
275
- updated_base_config = replace(self.base_config, **kwargs)
276
- updated_configs = [
277
- replace(cfg, **kwargs) for cfg in self.quantization_configurations
278
- ]
279
- return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
301
+ # Clone and update base_config
302
+ updated_base_config = self.base_config.clone_and_edit(**kwargs) if self.base_config else None
303
+
304
+ # Clone and update all configurations
305
+ updated_configs = tuple(cfg.clone_and_edit(**kwargs) for cfg in self.quantization_configurations)
280
306
 
281
- def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions':
307
+ return self.copy(update={
308
+ 'base_config': updated_base_config,
309
+ 'quantization_configurations': updated_configs
310
+ })
311
+
312
+ def clone_and_edit_weight_attribute(
313
+ self,
314
+ attrs: Optional[List[str]] = None,
315
+ **kwargs
316
+ ) -> 'QuantizationConfigOptions':
282
317
  """
283
318
  Clones the quantization configurations and edits some of their attributes' parameters.
284
319
 
285
320
  Args:
286
- attrs (List[str]): Attributes names to clone and edit their configurations. If None, updates all attributes.
287
- **kwargs: Keyword arguments to edit in the attributes configuration.
321
+ attrs (Optional[List[str]]): Attribute names to clone and edit their configurations. If None, updates all attributes.
322
+ **kwargs: Keyword arguments to edit in the attributes' configuration.
288
323
 
289
324
  Returns:
290
- QuantizationConfigOptions: A new instance of QuantizationConfigOptions with edited attributes configurations.
325
+ QuantizationConfigOptions: A new instance with edited attributes configurations.
291
326
  """
292
327
  updated_base_config = self.base_config
293
328
  updated_configs = []
329
+
294
330
  for qc in self.quantization_configurations:
295
331
  if attrs is None:
296
332
  attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
297
333
  else:
298
334
  attrs_to_update = attrs
335
+
299
336
  # Ensure all attributes exist in the config
300
337
  for attr in attrs_to_update:
301
338
  if attr not in qc.attr_weights_configs_mapping:
302
- Logger.critical(f"{attr} does not exist in {qc}.") # pragma: no cover
339
+ Logger.critical(f"Attribute '{attr}' does not exist in {qc}.") # pragma: no cover
340
+
341
+ # Update the specified attributes
303
342
  updated_attr_mapping = {
304
343
  attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
305
344
  for attr in attrs_to_update
306
345
  }
307
- if qc == updated_base_config:
308
- updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping)
309
- updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
310
- return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
311
346
 
312
- def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
347
+ # If the current config is the base_config, update it accordingly
348
+ if qc == self.base_config:
349
+ updated_base_config = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping)
350
+
351
+ # Update the current config with the new attribute mappings
352
+ updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping)
353
+ updated_configs.append(updated_cfg)
354
+
355
+ return self.copy(update={
356
+ 'base_config': updated_base_config,
357
+ 'quantization_configurations': tuple(updated_configs)
358
+ })
359
+
360
+ def clone_and_map_weights_attr_keys(
361
+ self,
362
+ layer_attrs_mapping: Optional[Dict[str, str]] = None
363
+ ) -> 'QuantizationConfigOptions':
313
364
  """
314
365
  Clones the quantization configurations and updates keys in attribute config mappings.
315
366
 
@@ -317,22 +368,32 @@ class QuantizationConfigOptions:
317
368
  layer_attrs_mapping (Optional[Dict[str, str]]): A mapping between attribute names.
318
369
 
319
370
  Returns:
320
- QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys.
371
+ QuantizationConfigOptions: A new instance with updated attribute keys.
321
372
  """
322
- updated_configs = []
323
373
  new_base_config = self.base_config
374
+ updated_configs = []
375
+
324
376
  for qc in self.quantization_configurations:
325
377
  if layer_attrs_mapping is None:
326
- new_attr_mapping = {}
378
+ new_attr_mapping = qc.attr_weights_configs_mapping
327
379
  else:
328
380
  new_attr_mapping = {
329
381
  layer_attrs_mapping.get(attr, attr): cfg
330
382
  for attr, cfg in qc.attr_weights_configs_mapping.items()
331
383
  }
384
+
385
+ # If the current config is the base_config, update it accordingly
332
386
  if qc == self.base_config:
333
- new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping)
334
- updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping))
335
- return replace(self, base_config=new_base_config, quantization_configurations=tuple(updated_configs))
387
+ new_base_config = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping)
388
+
389
+ # Update the current config with the new attribute mappings
390
+ updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping)
391
+ updated_configs.append(updated_cfg)
392
+
393
+ return self.copy(update={
394
+ 'base_config': new_base_config,
395
+ 'quantization_configurations': tuple(updated_configs)
396
+ })
336
397
 
337
398
  def get_info(self) -> Dict[str, Any]:
338
399
  """
@@ -341,18 +402,16 @@ class QuantizationConfigOptions:
341
402
  Returns:
342
403
  dict: Information about the quantization configuration options as a dictionary.
343
404
  """
344
- return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
345
-
405
+ return {f'option_{i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
346
406
 
347
- @dataclass(frozen=True)
348
- class TargetPlatformModelComponent:
407
+ class TargetPlatformModelComponent(BaseModel):
349
408
  """
350
409
  Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
351
410
  """
352
- pass
411
+ class Config:
412
+ frozen = True
353
413
 
354
414
 
355
- @dataclass(frozen=True)
356
415
  class OperatorsSetBase(TargetPlatformModelComponent):
357
416
  """
358
417
  Base class to represent a set of a target platform model component of operator set types.
@@ -361,91 +420,151 @@ class OperatorsSetBase(TargetPlatformModelComponent):
361
420
  pass
362
421
 
363
422
 
364
- @dataclass(frozen=True)
365
423
  class OperatorsSet(OperatorsSetBase):
366
424
  """
367
425
  Set of operators that are represented by a unique label.
368
426
 
369
427
  Attributes:
370
- name (str): The set's label (must be unique within a TargetPlatformModel).
371
- qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
372
- If None, it represents a fusing set.
373
- is_default (bool): Indicates whether this set is the default quantization configuration
374
- for the TargetPlatformModel or a fusing set.
428
+ name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformModel).
429
+ qc_options (Optional[QuantizationConfigOptions]): Configuration options to use for this set of operations.
430
+ If None, it represents a fusing set.
431
+ type (Literal["OperatorsSet"]): Fixed type identifier.
375
432
  """
376
433
  name: Union[str, OperatorSetNames]
377
- qc_options: QuantizationConfigOptions = None
434
+ qc_options: Optional[QuantizationConfigOptions] = None
435
+
436
+ # Define a private attribute _type
437
+ type: Literal["OperatorsSet"] = "OperatorsSet"
438
+
439
+ class Config:
440
+ frozen = True
378
441
 
379
442
  def get_info(self) -> Dict[str, Any]:
380
443
  """
381
444
  Get information about the set as a dictionary.
382
445
 
383
446
  Returns:
384
- Dict[str, Any]: A dictionary containing the set name and
385
- whether it is the default quantization configuration.
447
+ Dict[str, Any]: A dictionary containing the set name.
386
448
  """
387
449
  return {"name": self.name}
388
450
 
389
451
 
390
- @dataclass(frozen=True)
391
452
  class OperatorSetConcat(OperatorsSetBase):
392
453
  """
393
454
  Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
394
455
 
395
456
  Attributes:
396
- operators_set (Tuple[OperatorsSet]): Tuple of operator sets to group.
397
- qc_options (None): Configuration options for the set, always None for concatenated sets.
398
- name (str): Concatenated name generated from the names of the operator sets.
457
+ operators_set (Tuple[OperatorsSet, ...]): Tuple of operator sets to group.
458
+ name (Optional[str]): Concatenated name generated from the names of the operator sets.
399
459
  """
400
- operators_set: Tuple[OperatorsSet]
401
- qc_options: None = field(default=None, init=False)
460
+ operators_set: Tuple[OperatorsSet, ...]
461
+ name: Optional[str] = None # Will be set in the validator if not given
402
462
 
403
- def __post_init__(self):
463
+ # Define a private attribute _type
464
+ type: Literal["OperatorSetConcat"] = "OperatorSetConcat"
465
+
466
+ class Config:
467
+ frozen = True
468
+
469
+ @root_validator(pre=True, allow_reuse=True)
470
+ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
404
471
  """
405
- Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
472
+ Validate the input and set the concatenated name based on the operators_set.
473
+
474
+ Args:
475
+ values (Dict[str, Any]): Input data.
476
+
477
+ Returns:
478
+ Dict[str, Any]: Modified input data with 'name' set.
479
+ """
480
+ operators_set = values['operators_set']
481
+
482
+ if len(operators_set) < 1:
483
+ Logger.critical("'operators_set' must contain at least one OperatorsSet") # pragma: no cover
406
484
 
407
- Calls the parent class's __post_init__ method and creates a concatenated name
408
- by joining the names of all operator sets in `operators_set`.
485
+ if values.get('name') is None:
486
+ # Generate the concatenated name from the operator sets
487
+ concatenated_name = "_".join([
488
+ op.name.value if isinstance(op.name, OperatorSetNames) else op.name
489
+ for op in operators_set
490
+ ])
491
+ values['name'] = concatenated_name
492
+
493
+ return values
494
+
495
+ def get_info(self) -> Dict[str, Any]:
409
496
  """
410
- # Generate the concatenated name from the operator sets
411
- concatenated_name = "_".join([op.name.value if hasattr(op.name, "value") else op.name for op in self.operators_set])
412
- # Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
413
- object.__setattr__(self, "name", concatenated_name)
497
+ Get information about the concatenated operator sets as a dictionary.
414
498
 
499
+ Returns:
500
+ Dict[str, Any]: A dictionary containing the concatenated name and operator sets information.
501
+ """
502
+ return {
503
+ "name": self.name,
504
+ "operators_set": [op.get_info() for op in self.operators_set]
505
+ }
415
506
 
416
- @dataclass(frozen=True)
417
507
  class Fusing(TargetPlatformModelComponent):
418
508
  """
419
509
  Fusing defines a tuple of operators that should be combined and treated as a single operator,
420
510
  hence no quantization is applied between them.
421
511
 
422
512
  Attributes:
423
- operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A tuple of operator groups,
513
+ operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat], ...]): A tuple of operator groups,
424
514
  each being either an OperatorSetConcat or an OperatorsSet.
425
- name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
515
+ name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
426
516
  """
427
- operator_groups: Tuple[Union[OperatorsSet, OperatorSetConcat]]
517
+ operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetConcat], Field(discriminator='type')], ...]
518
+ name: Optional[str] = None # Will be set in the validator if not given.
519
+
520
+ class Config:
521
+ frozen = True
428
522
 
429
- def __post_init__(self):
523
+ @root_validator(pre=True, allow_reuse=True)
524
+ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
430
525
  """
431
- Post-initialization processing for input validation and name generation.
526
+ Validate the operator_groups and set the name by concatenating operator group names.
432
527
 
433
- Calls the parent class's __post_init__ method, validates the operator_groups,
434
- and generates the name if not explicitly provided.
528
+ Args:
529
+ values (Dict[str, Any]): Input data.
435
530
 
436
- Raises:
437
- Logger critical if operator_groups is not a tuple or if it contains fewer than two operators.
531
+ Returns:
532
+ Dict[str, Any]: Modified input data with 'name' set.
438
533
  """
439
- # Validate the operator_groups
440
- if not isinstance(self.operator_groups, tuple):
441
- Logger.critical(
442
- f"Operator groups should be of type 'tuple' but is {type(self.operator_groups)}.") # pragma: no cover
443
- if len(self.operator_groups) < 2:
444
- Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
534
+ operator_groups = values.get('operator_groups')
445
535
 
446
- # Generate the name from the operator groups if not provided
447
- generated_name = '_'.join([x.name.value if hasattr(x.name, 'value') else x.name for x in self.operator_groups])
448
- object.__setattr__(self, 'name', generated_name)
536
+ # When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
537
+ if isinstance(operator_groups, list):
538
+ values['operator_groups'] = tuple(operator_groups)
539
+
540
+ if values.get('name') is None:
541
+ # Generate the concatenated name from the operator groups
542
+ concatenated_name = "_".join([
543
+ op.name.value if isinstance(op.name, OperatorSetNames) else op.name
544
+ for op in values['operator_groups']
545
+ ])
546
+ values['name'] = concatenated_name
547
+
548
+ return values
549
+
550
+ @root_validator(allow_reuse=True)
551
+ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
552
+ """
553
+ Perform validation after the model has been instantiated.
554
+
555
+ Args:
556
+ values (Dict[str, Any]): The instantiated fusing.
557
+
558
+ Returns:
559
+ Dict[str, Any]: The validated values.
560
+ """
561
+ operator_groups = values.get('operator_groups')
562
+
563
+ # Validate that there are at least two operator groups
564
+ if len(operator_groups) < 2:
565
+ Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
566
+
567
+ return values
449
568
 
450
569
  def contains(self, other: Any) -> bool:
451
570
  """
@@ -483,56 +602,75 @@ class Fusing(TargetPlatformModelComponent):
483
602
  or just the sequence of operator groups if no name is set.
484
603
  """
485
604
  if self.name is not None:
486
- return {self.name: ' -> '.join([x.name for x in self.operator_groups])}
487
- return ' -> '.join([x.name for x in self.operator_groups])
488
-
605
+ return {
606
+ self.name: ' -> '.join([
607
+ x.name.value if isinstance(x.name, OperatorSetNames) else x.name
608
+ for x in self.operator_groups
609
+ ])
610
+ }
611
+ return ' -> '.join([
612
+ x.name.value if isinstance(x.name, OperatorSetNames) else x.name
613
+ for x in self.operator_groups
614
+ ])
489
615
 
490
- @dataclass(frozen=True)
491
- class TargetPlatformModel:
616
+ class TargetPlatformModel(BaseModel):
492
617
  """
493
618
  Represents the hardware configuration used for quantized model inference.
494
619
 
495
620
  Attributes:
496
621
  default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
622
+ operator_set (Optional[Tuple[OperatorsSet, ...]]): Tuple of operator sets within the model.
623
+ fusing_patterns (Optional[Tuple[Fusing, ...]]): Tuple of fusing patterns for the model.
497
624
  tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
498
625
  tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
499
626
  tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
500
627
  add_metadata (bool): Flag to determine if metadata should be added.
501
628
  name (str): Name of the Target Platform Model.
502
- operator_set (Tuple[OperatorsSetBase]): Tuple of operator sets within the model.
503
- fusing_patterns (Tuple[Fusing]): Tuple of fusing patterns for the model.
504
629
  is_simd_padding (bool): Indicates if SIMD padding is applied.
505
630
  SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
506
631
  """
507
632
  default_qco: QuantizationConfigOptions
633
+ operator_set: Optional[Tuple[OperatorsSet, ...]]
634
+ fusing_patterns: Optional[Tuple[Fusing, ...]]
508
635
  tpc_minor_version: Optional[int]
509
636
  tpc_patch_version: Optional[int]
510
637
  tpc_platform_type: Optional[str]
511
638
  add_metadata: bool = True
512
- name: str = "default_tp_model"
513
- operator_set: Tuple[OperatorsSetBase] = None
514
- fusing_patterns: Tuple[Fusing] = None
639
+ name: Optional[str] = "default_tp_model"
515
640
  is_simd_padding: bool = False
516
641
 
517
642
  SCHEMA_VERSION: int = 1
518
643
 
519
- def __post_init__(self):
644
+ class Config:
645
+ frozen = True
646
+
647
+ @root_validator(allow_reuse=True)
648
+ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
520
649
  """
521
- Post-initialization processing for input validation.
650
+ Perform validation after the model has been instantiated.
651
+
652
+ Args:
653
+ values (Dict[str, Any]): The instantiated target platform model.
522
654
 
523
- Raises:
524
- Logger critical if the default_qco is not an instance of QuantizationConfigOptions
525
- or if it contains more than one quantization configuration.
655
+ Returns:
656
+ Dict[str, Any]: The validated values.
526
657
  """
527
658
  # Validate `default_qco`
528
- if not isinstance(self.default_qco, QuantizationConfigOptions):
529
- Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover
530
- if len(self.default_qco.quantization_configurations) != 1:
531
- Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
532
-
533
- opsets_names = [op.name.value if hasattr(op.name, "value") else op.name for op in self.operator_set] if self.operator_set else []
534
- if len(set(opsets_names)) != len(opsets_names):
535
- Logger.critical("Operator Sets must have unique names.") # pragma: no cover
659
+ default_qco = values.get('default_qco')
660
+ if len(default_qco.quantization_configurations) != 1:
661
+ Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
662
+
663
+ # Validate `operator_set` uniqueness
664
+ operator_set = values.get('operator_set')
665
+ if operator_set is not None:
666
+ opsets_names = [
667
+ op.name.value if isinstance(op.name, OperatorSetNames) else op.name
668
+ for op in operator_set
669
+ ]
670
+ if len(set(opsets_names)) != len(opsets_names):
671
+ Logger.critical("Operator Sets must have unique names.") # pragma: no cover
672
+
673
+ return values
536
674
 
537
675
  def get_info(self) -> Dict[str, Any]:
538
676
  """
@@ -547,11 +685,8 @@ class TargetPlatformModel:
547
685
  "Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
548
686
  }
549
687
 
550
-
551
688
  def show(self):
552
689
  """
553
-
554
690
  Display the TargetPlatformModel.
555
-
556
691
  """
557
692
  pprint.pprint(self.get_info(), sort_dicts=False)