mct-nightly 2.2.0.20241221.519__py3-none-any.whl → 2.2.0.20241223.525__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {mct_nightly-2.2.0.20241221.519.dist-info → mct_nightly-2.2.0.20241223.525.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241221.519.dist-info → mct_nightly-2.2.0.20241223.525.dist-info}/RECORD +26 -28
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_graph.py +1 -1
  5. model_compression_toolkit/core/common/graph/base_node.py +3 -3
  6. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +4 -4
  7. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
  8. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -0
  9. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +4 -5
  10. model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -170
  11. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -1
  12. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -1
  13. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +7 -4
  14. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +50 -51
  15. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +54 -52
  16. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +57 -53
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +52 -51
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +53 -51
  19. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +59 -57
  20. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +54 -52
  21. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +90 -83
  22. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +26 -24
  23. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +57 -55
  24. model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +0 -67
  25. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +0 -30
  26. {mct_nightly-2.2.0.20241221.519.dist-info → mct_nightly-2.2.0.20241223.525.dist-info}/LICENSE.md +0 -0
  27. {mct_nightly-2.2.0.20241221.519.dist-info → mct_nightly-2.2.0.20241223.525.dist-info}/WHEEL +0 -0
  28. {mct_nightly-2.2.0.20241221.519.dist-info → mct_nightly-2.2.0.20241223.525.dist-info}/top_level.txt +0 -0
@@ -21,8 +21,6 @@ from mct_quantizers import QuantizationMethod
21
21
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
22
22
  from model_compression_toolkit.logger import Logger
23
23
  from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
24
- from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
25
- _current_tp_model
26
24
 
27
25
  class OperatorSetNames(Enum):
28
26
  OPSET_CONV = "Conv"
@@ -61,7 +59,6 @@ class OperatorSetNames(Enum):
61
59
  OPSET_DROPOUT = "Dropout"
62
60
  OPSET_SPLIT = "Split"
63
61
  OPSET_CHUNK = "Chunk"
64
- OPSET_UNBIND = "Unbind"
65
62
  OPSET_MAXPOOL = "MaxPool"
66
63
  OPSET_SIZE = "Size"
67
64
  OPSET_SHAPE = "Shape"
@@ -74,6 +71,7 @@ class OperatorSetNames(Enum):
74
71
  OPSET_ZERO_PADDING2d = "ZeroPadding2D"
75
72
  OPSET_CAST = "Cast"
76
73
  OPSET_STRIDED_SLICE = "StridedSlice"
74
+ OPSET_SSD_POST_PROCESS = "SSDPostProcess"
77
75
 
78
76
  @classmethod
79
77
  def get_values(cls):
@@ -225,10 +223,10 @@ class QuantizationConfigOptions:
225
223
  QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator.
226
224
 
227
225
  Attributes:
228
- quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather.
226
+ quantization_configurations (Tuple[OpQuantizationConfig]): Tuple of possible OpQuantizationConfig to gather.
229
227
  base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
230
228
  """
231
- quantization_config_list: List[OpQuantizationConfig]
229
+ quantization_configurations: Tuple[OpQuantizationConfig]
232
230
  base_config: Optional[OpQuantizationConfig] = None
233
231
 
234
232
  def __post_init__(self):
@@ -236,32 +234,32 @@ class QuantizationConfigOptions:
236
234
  Post-initialization processing for input validation.
237
235
 
238
236
  Raises:
239
- Logger critical if quantization_config_list is not a list, contains invalid elements, or if base_config is not set correctly.
237
+ Logger critical if quantization_configurations is not a tuple, contains invalid elements, or if base_config is not set correctly.
240
238
  """
241
- # Validate `quantization_config_list`
242
- if not isinstance(self.quantization_config_list, list):
239
+ # Validate `quantization_configurations`
240
+ if not isinstance(self.quantization_configurations, tuple):
243
241
  Logger.critical(
244
- f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") # pragma: no cover
245
- for cfg in self.quantization_config_list:
242
+ f"'quantization_configurations' must be a tuple, but received: {type(self.quantization_configurations)}.") # pragma: no cover
243
+ for cfg in self.quantization_configurations:
246
244
  if not isinstance(cfg, OpQuantizationConfig):
247
245
  Logger.critical(
248
246
  f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") # pragma: no cover
249
247
 
250
248
  # Handle base_config
251
- if len(self.quantization_config_list) > 1:
249
+ if len(self.quantization_configurations) > 1:
252
250
  if self.base_config is None:
253
251
  Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") # pragma: no cover
254
- if not any(self.base_config == cfg for cfg in self.quantization_config_list):
255
- Logger.critical(f"'base_config' must be included in the quantization config options list.") # pragma: no cover
256
- elif len(self.quantization_config_list) == 1:
252
+ if not any(self.base_config == cfg for cfg in self.quantization_configurations):
253
+ Logger.critical(f"'base_config' must be included in the quantization config options.") # pragma: no cover
254
+ elif len(self.quantization_configurations) == 1:
257
255
  if self.base_config is None:
258
- object.__setattr__(self, 'base_config', self.quantization_config_list[0])
259
- elif self.base_config != self.quantization_config_list[0]:
256
+ object.__setattr__(self, 'base_config', self.quantization_configurations[0])
257
+ elif self.base_config != self.quantization_configurations[0]:
260
258
  Logger.critical(
261
- "'base_config' should be the same as the sole item in 'quantization_config_list'.") # pragma: no cover
259
+ "'base_config' should be the same as the sole item in 'quantization_configurations'.") # pragma: no cover
262
260
 
263
- elif len(self.quantization_config_list) == 0:
264
- Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") # pragma: no cover
261
+ elif len(self.quantization_configurations) == 0:
262
+ Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided configurations is empty.") # pragma: no cover
265
263
 
266
264
  def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
267
265
  """
@@ -274,10 +272,10 @@ class QuantizationConfigOptions:
274
272
  A new instance of QuantizationConfigOptions with updated configurations.
275
273
  """
276
274
  updated_base_config = replace(self.base_config, **kwargs)
277
- updated_configs_list = [
278
- replace(cfg, **kwargs) for cfg in self.quantization_config_list
275
+ updated_configs = [
276
+ replace(cfg, **kwargs) for cfg in self.quantization_configurations
279
277
  ]
280
- return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs_list)
278
+ return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
281
279
 
282
280
  def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions':
283
281
  """
@@ -292,7 +290,7 @@ class QuantizationConfigOptions:
292
290
  """
293
291
  updated_base_config = self.base_config
294
292
  updated_configs = []
295
- for qc in self.quantization_config_list:
293
+ for qc in self.quantization_configurations:
296
294
  if attrs is None:
297
295
  attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
298
296
  else:
@@ -300,7 +298,7 @@ class QuantizationConfigOptions:
300
298
  # Ensure all attributes exist in the config
301
299
  for attr in attrs_to_update:
302
300
  if attr not in qc.attr_weights_configs_mapping:
303
- Logger.critical(f"{attr} does not exist in {qc}.")
301
+ Logger.critical(f"{attr} does not exist in {qc}.") # pragma: no cover
304
302
  updated_attr_mapping = {
305
303
  attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
306
304
  for attr in attrs_to_update
@@ -308,7 +306,7 @@ class QuantizationConfigOptions:
308
306
  if qc == updated_base_config:
309
307
  updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping)
310
308
  updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
311
- return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs)
309
+ return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
312
310
 
313
311
  def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
314
312
  """
@@ -322,7 +320,7 @@ class QuantizationConfigOptions:
322
320
  """
323
321
  updated_configs = []
324
322
  new_base_config = self.base_config
325
- for qc in self.quantization_config_list:
323
+ for qc in self.quantization_configurations:
326
324
  if layer_attrs_mapping is None:
327
325
  new_attr_mapping = {}
328
326
  else:
@@ -333,7 +331,7 @@ class QuantizationConfigOptions:
333
331
  if qc == self.base_config:
334
332
  new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping)
335
333
  updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping))
336
- return replace(self, base_config=new_base_config, quantization_config_list=updated_configs)
334
+ return replace(self, base_config=new_base_config, quantization_configurations=tuple(updated_configs))
337
335
 
338
336
  def get_info(self) -> Dict[str, Any]:
339
337
  """
@@ -342,7 +340,7 @@ class QuantizationConfigOptions:
342
340
  Returns:
343
341
  dict: Information about the quantization configuration options as a dictionary.
344
342
  """
345
- return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)}
343
+ return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
346
344
 
347
345
 
348
346
  @dataclass(frozen=True)
@@ -350,22 +348,7 @@ class TargetPlatformModelComponent:
350
348
  """
351
349
  Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
352
350
  """
353
-
354
- def __post_init__(self):
355
- """
356
- Post-initialization to register the component with the current TargetPlatformModel.
357
- """
358
- _current_tp_model.get().append_component(self)
359
-
360
- def get_info(self) -> Dict[str, Any]:
361
- """
362
- Get information about the component to display.
363
-
364
- Returns:
365
- Dict[str, Any]: Returns an empty dictionary. The actual component should override
366
- this method to provide relevant information.
367
- """
368
- return {}
351
+ pass
369
352
 
370
353
 
371
354
  @dataclass(frozen=True)
@@ -374,12 +357,7 @@ class OperatorsSetBase(TargetPlatformModelComponent):
374
357
  Base class to represent a set of a target platform model component of operator set types.
375
358
  Inherits from TargetPlatformModelComponent.
376
359
  """
377
- def __post_init__(self):
378
- """
379
- Post-initialization to ensure the component is registered with the TargetPlatformModel.
380
- Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel.
381
- """
382
- super().__post_init__()
360
+ pass
383
361
 
384
362
 
385
363
  @dataclass(frozen=True)
@@ -394,23 +372,9 @@ class OperatorsSet(OperatorsSetBase):
394
372
  is_default (bool): Indicates whether this set is the default quantization configuration
395
373
  for the TargetPlatformModel or a fusing set.
396
374
  """
397
- name: str
375
+ name: Union[str, OperatorSetNames]
398
376
  qc_options: QuantizationConfigOptions = None
399
377
 
400
- def __post_init__(self):
401
- """
402
- Post-initialization processing to mark the operator set as default if applicable.
403
-
404
- Calls the parent class's __post_init__ method and sets `is_default` to True
405
- if this set corresponds to the default quantization configuration for the
406
- TargetPlatformModel or if it is a fusing set.
407
-
408
- """
409
- super().__post_init__()
410
- is_fusing_set = self.qc_options is None
411
- is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
412
- object.__setattr__(self, 'is_default', is_default)
413
-
414
378
  def get_info(self) -> Dict[str, Any]:
415
379
  """
416
380
  Get information about the set as a dictionary.
@@ -419,83 +383,67 @@ class OperatorsSet(OperatorsSetBase):
419
383
  Dict[str, Any]: A dictionary containing the set name and
420
384
  whether it is the default quantization configuration.
421
385
  """
422
- return {"name": self.name,
423
- "is_default_qc": self.is_default}
386
+ return {"name": self.name}
424
387
 
425
388
 
426
389
  @dataclass(frozen=True)
427
390
  class OperatorSetConcat(OperatorsSetBase):
428
391
  """
429
- Concatenate a list of operator sets to treat them similarly in different places (like fusing).
392
+ Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
430
393
 
431
394
  Attributes:
432
- op_set_list (List[OperatorsSet]): List of operator sets to group.
395
+ operators_set (Tuple[OperatorsSet]): Tuple of operator sets to group.
433
396
  qc_options (None): Configuration options for the set, always None for concatenated sets.
434
- name (str): Concatenated name generated from the names of the operator sets in the list.
397
+ name (str): Concatenated name generated from the names of the operator sets.
435
398
  """
436
- op_set_list: List[OperatorsSet] = field(default_factory=list)
399
+ operators_set: Tuple[OperatorsSet]
437
400
  qc_options: None = field(default=None, init=False)
438
- name: str = None
439
401
 
440
402
  def __post_init__(self):
441
403
  """
442
404
  Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
443
405
 
444
406
  Calls the parent class's __post_init__ method and creates a concatenated name
445
- by joining the names of all operator sets in `op_set_list`.
407
+ by joining the names of all operator sets in `operators_set`.
446
408
  """
447
- super().__post_init__()
448
409
  # Generate the concatenated name from the operator sets
449
- concatenated_name = "_".join([op.name for op in self.op_set_list])
410
+ concatenated_name = "_".join([op.name.value if hasattr(op.name, "value") else op.name for op in self.operators_set])
450
411
  # Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
451
412
  object.__setattr__(self, "name", concatenated_name)
452
413
 
453
- def get_info(self) -> Dict[str, Any]:
454
- """
455
- Get information about the concatenated set as a dictionary.
456
-
457
- Returns:
458
- Dict[str, Any]: A dictionary containing the concatenated name and
459
- the list of names of the operator sets in `op_set_list`.
460
- """
461
- return {"name": self.name,
462
- OPS_SET_LIST: [s.name for s in self.op_set_list]}
463
-
464
414
 
465
415
  @dataclass(frozen=True)
466
416
  class Fusing(TargetPlatformModelComponent):
467
417
  """
468
- Fusing defines a list of operators that should be combined and treated as a single operator,
418
+ Fusing defines a tuple of operators that should be combined and treated as a single operator,
469
419
  hence no quantization is applied between them.
470
420
 
471
421
  Attributes:
472
- operator_groups_list (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups,
422
+ operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A tuple of operator groups,
473
423
  each being either an OperatorSetConcat or an OperatorsSet.
474
424
  name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
475
425
  """
476
- operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]]
477
- name: str = None
426
+ operator_groups: Tuple[Union[OperatorsSet, OperatorSetConcat]]
478
427
 
479
428
  def __post_init__(self):
480
429
  """
481
430
  Post-initialization processing for input validation and name generation.
482
431
 
483
- Calls the parent class's __post_init__ method, validates the operator_groups_list,
432
+ Calls the parent class's __post_init__ method, validates the operator_groups,
484
433
  and generates the name if not explicitly provided.
485
434
 
486
435
  Raises:
487
- Logger critical if operator_groups_list is not a list or if it contains fewer than two operators.
436
+ Logger critical if operator_groups is not a tuple or if it contains fewer than two operators.
488
437
  """
489
- super().__post_init__()
490
- # Validate the operator_groups_list
491
- if not isinstance(self.operator_groups_list, list):
438
+ # Validate the operator_groups
439
+ if not isinstance(self.operator_groups, tuple):
492
440
  Logger.critical(
493
- f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") # pragma: no cover
494
- if len(self.operator_groups_list) < 2:
441
+ f"Operator groups should be of type 'tuple' but is {type(self.operator_groups)}.") # pragma: no cover
442
+ if len(self.operator_groups) < 2:
495
443
  Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
496
444
 
497
445
  # Generate the name from the operator groups if not provided
498
- generated_name = '_'.join([x.name for x in self.operator_groups_list])
446
+ generated_name = '_'.join([x.name.value if hasattr(x.name, 'value') else x.name for x in self.operator_groups])
499
447
  object.__setattr__(self, 'name', generated_name)
500
448
 
501
449
  def contains(self, other: Any) -> bool:
@@ -512,11 +460,11 @@ class Fusing(TargetPlatformModelComponent):
512
460
  return False
513
461
 
514
462
  # Check for containment by comparing operator groups
515
- for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
516
- for j in range(len(other.operator_groups_list)):
517
- if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
518
- isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
519
- other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
463
+ for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
464
+ for j in range(len(other.operator_groups)):
465
+ if self.operator_groups[i + j] != other.operator_groups[j] and not (
466
+ isinstance(self.operator_groups[i + j], OperatorSetConcat) and (
467
+ other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
520
468
  break
521
469
  else:
522
470
  # If all checks pass, the other Fusing instance is contained
@@ -534,8 +482,8 @@ class Fusing(TargetPlatformModelComponent):
534
482
  or just the sequence of operator groups if no name is set.
535
483
  """
536
484
  if self.name is not None:
537
- return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
538
- return ' -> '.join([x.name for x in self.operator_groups_list])
485
+ return {self.name: ' -> '.join([x.name for x in self.operator_groups])}
486
+ return ' -> '.join([x.name for x in self.operator_groups])
539
487
 
540
488
 
541
489
  @dataclass(frozen=True)
@@ -550,8 +498,8 @@ class TargetPlatformModel:
550
498
  tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
551
499
  add_metadata (bool): Flag to determine if metadata should be added.
552
500
  name (str): Name of the Target Platform Model.
553
- operator_set (List[OperatorsSetBase]): List of operator sets within the model.
554
- fusing_patterns (List[Fusing]): List of fusing patterns for the model.
501
+ operator_set (Tuple[OperatorsSetBase]): Tuple of operator sets within the model.
502
+ fusing_patterns (Tuple[Fusing]): Tuple of fusing patterns for the model.
555
503
  is_simd_padding (bool): Indicates if SIMD padding is applied.
556
504
  SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
557
505
  """
@@ -561,8 +509,8 @@ class TargetPlatformModel:
561
509
  tpc_platform_type: Optional[str]
562
510
  add_metadata: bool = True
563
511
  name: str = "default_tp_model"
564
- operator_set: List[OperatorsSetBase] = field(default_factory=list)
565
- fusing_patterns: List[Fusing] = field(default_factory=list)
512
+ operator_set: Tuple[OperatorsSetBase] = None
513
+ fusing_patterns: Tuple[Fusing] = None
566
514
  is_simd_padding: bool = False
567
515
 
568
516
  SCHEMA_VERSION: int = 1
@@ -578,26 +526,12 @@ class TargetPlatformModel:
578
526
  # Validate `default_qco`
579
527
  if not isinstance(self.default_qco, QuantizationConfigOptions):
580
528
  Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover
581
- if len(self.default_qco.quantization_config_list) != 1:
529
+ if len(self.default_qco.quantization_configurations) != 1:
582
530
  Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
583
531
 
584
- def append_component(self, tp_model_component: TargetPlatformModelComponent):
585
- """
586
- Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet).
587
-
588
- Args:
589
- tp_model_component (TargetPlatformModelComponent): Component to attach to the model.
590
-
591
- Raises:
592
- Logger critical if the component is not an instance of Fusing or OperatorsSetBase.
593
- """
594
- if isinstance(tp_model_component, Fusing):
595
- self.fusing_patterns.append(tp_model_component)
596
- elif isinstance(tp_model_component, OperatorsSetBase):
597
- self.operator_set.append(tp_model_component)
598
- else:
599
- Logger.critical(
600
- f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") # pragma: no cover
532
+ opsets_names = [op.name.value if hasattr(op.name, "value") else op.name for op in self.operator_set] if self.operator_set else []
533
+ if len(set(opsets_names)) != len(opsets_names):
534
+ Logger.critical("Operator Sets must have unique names.") # pragma: no cover
601
535
 
602
536
  def get_info(self) -> Dict[str, Any]:
603
537
  """
@@ -608,51 +542,10 @@ class TargetPlatformModel:
608
542
  """
609
543
  return {
610
544
  "Model name": self.name,
611
- "Operators sets": [o.get_info() for o in self.operator_set],
612
- "Fusing patterns": [f.get_info() for f in self.fusing_patterns],
545
+ "Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
546
+ "Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
613
547
  }
614
548
 
615
- def __validate_model(self):
616
- """
617
- Validate the model's configuration to ensure its integrity.
618
-
619
- Raises:
620
- Logger critical if the model contains multiple operator sets with the same name.
621
- """
622
- opsets_names = [op.name for op in self.operator_set]
623
- if len(set(opsets_names)) != len(opsets_names):
624
- Logger.critical("Operator Sets must have unique names.") # pragma: no cover
625
-
626
- def __enter__(self) -> 'TargetPlatformModel':
627
- """
628
- Start defining the TargetPlatformModel using a 'with' statement.
629
-
630
- Returns:
631
- TargetPlatformModel: The initialized TargetPlatformModel object.
632
- """
633
- _current_tp_model.set(self)
634
- return self
635
-
636
- def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel':
637
- """
638
- Finalize and validate the TargetPlatformModel at the end of the 'with' clause.
639
-
640
- Args:
641
- exc_type: Exception type, if any occurred.
642
- exc_value: Exception value, if any occurred.
643
- tb: Traceback object, if an exception occurred.
644
-
645
- Raises:
646
- The exception raised in the 'with' block, if any.
647
-
648
- Returns:
649
- TargetPlatformModel: The validated TargetPlatformModel object.
650
- """
651
- if exc_value is not None:
652
- raise exc_value
653
- self.__validate_model()
654
- _current_tp_model.reset()
655
- return self
656
549
 
657
550
  def show(self):
658
551
  """
@@ -15,7 +15,6 @@
15
15
 
16
16
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
17
17
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
18
- from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options
19
18
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSet, \
20
19
  OperatorSetConcat, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
21
20
 
@@ -90,7 +90,7 @@ class OperationsToLayers:
90
90
  return o.layers
91
91
  if isinstance(op, OperatorSetConcat): # If its a concat - return all layers from all OperatorsSets that in the OperatorSetConcat
92
92
  layers = []
93
- for o in op.op_set_list:
93
+ for o in op.operators_set:
94
94
  layers.extend(self.get_layers_by_op(o))
95
95
  return layers
96
96
  Logger.warning(f'{op.name} is not in model.')
@@ -100,8 +100,10 @@ class TargetPlatformCapabilities(ImmutableClass):
100
100
 
101
101
  """
102
102
  res = []
103
+ if self.tp_model.fusing_patterns is None:
104
+ return res
103
105
  for p in self.tp_model.fusing_patterns:
104
- ops = [self.get_layers_by_opset(x) for x in p.operator_groups_list]
106
+ ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
105
107
  res.extend(itertools.product(*ops))
106
108
  return [list(x) for x in res]
107
109
 
@@ -207,9 +209,10 @@ class TargetPlatformCapabilities(ImmutableClass):
207
209
  Remove OperatorSets names from the list of the unused sets (so a warning
208
210
  will not be displayed).
209
211
  """
210
- for f in self.tp_model.fusing_patterns:
211
- for s in f.operator_groups_list:
212
- self.remove_opset_from_not_used_list(s.name)
212
+ if self.tp_model.fusing_patterns is not None:
213
+ for f in self.tp_model.fusing_patterns:
214
+ for s in f.operator_groups:
215
+ self.remove_opset_from_not_used_list(s.name)
213
216
 
214
217
  def remove_opset_from_not_used_list(self,
215
218
  opset_to_remove: str):
@@ -153,7 +153,54 @@ def generate_tp_model(default_config: OpQuantizationConfig,
153
153
  # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example).
154
154
  # If the QuantizationConfigOptions contains only one configuration,
155
155
  # this configuration will be used for the operation quantization:
156
- default_configuration_options = schema.QuantizationConfigOptions([default_config])
156
+ default_configuration_options = schema.QuantizationConfigOptions(tuple([default_config]))
157
+
158
+ # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
159
+ mixed_precision_configuration_options = schema.QuantizationConfigOptions(tuple(mixed_precision_cfg_list),
160
+ base_config=base_config)
161
+
162
+ # Create an OperatorsSet to represent a set of operations.
163
+ # Each OperatorsSet has a unique label.
164
+ # If a quantization configuration options is passed, these options will
165
+ # be used for operations that will be attached to this set's label.
166
+ # Otherwise, it will be a configure-less set (used in fusing):
167
+ operator_set = []
168
+ fusing_patterns = []
169
+
170
+ operator_set.append(schema.OperatorsSet("NoQuantization",
171
+ default_configuration_options.clone_and_edit(enable_activation_quantization=False)
172
+ .clone_and_edit_weight_attribute(enable_weights_quantization=False)))
173
+
174
+ # Define operator sets that use mixed_precision_configuration_options:
175
+ conv = schema.OperatorsSet("Conv", mixed_precision_configuration_options)
176
+ fc = schema.OperatorsSet("FullyConnected", mixed_precision_configuration_options)
177
+
178
+ # Define operations sets without quantization configuration
179
+ # options (useful for creating fusing patterns, for example):
180
+ any_relu = schema.OperatorsSet("AnyReLU")
181
+ add = schema.OperatorsSet("Add")
182
+ sub = schema.OperatorsSet("Sub")
183
+ mul = schema.OperatorsSet("Mul")
184
+ div = schema.OperatorsSet("Div")
185
+ prelu = schema.OperatorsSet("PReLU")
186
+ swish = schema.OperatorsSet("Swish")
187
+ sigmoid = schema.OperatorsSet("Sigmoid")
188
+ tanh = schema.OperatorsSet("Tanh")
189
+
190
+ operator_set.extend([conv, fc, any_relu, add, sub, mul, div, prelu, swish, sigmoid, tanh])
191
+ # Combine multiple operators into a single operator to avoid quantization between
192
+ # them. To do this we define fusing patterns using the OperatorsSets that were created.
193
+ # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
194
+ activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
195
+ activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
196
+ any_binary = schema.OperatorSetConcat([add, sub, mul, div])
197
+
198
+ # ------------------- #
199
+ # Fusions
200
+ # ------------------- #
201
+ fusing_patterns.append(schema.Fusing((conv, activations_after_conv_to_fuse)))
202
+ fusing_patterns.append(schema.Fusing((fc, activations_after_fc_to_fuse)))
203
+ fusing_patterns.append(schema.Fusing((any_binary, any_relu)))
157
204
 
158
205
  # Create a TargetPlatformModel and set its default quantization config.
159
206
  # This default configuration will be used for all operations
@@ -163,57 +210,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
163
210
  tpc_minor_version=1,
164
211
  tpc_patch_version=0,
165
212
  tpc_platform_type=IMX500_TP_MODEL,
213
+ operator_set=tuple(operator_set),
214
+ fusing_patterns=tuple(fusing_patterns),
166
215
  name=name,
167
216
  add_metadata=False,
168
217
  is_simd_padding=True)
169
-
170
- # To start defining the model's components (such as operator sets, and fusing patterns),
171
- # use 'with' the TargetPlatformModel instance, and create them as below:
172
- with generated_tpc:
173
- # Create an OperatorsSet to represent a set of operations.
174
- # Each OperatorsSet has a unique label.
175
- # If a quantization configuration options is passed, these options will
176
- # be used for operations that will be attached to this set's label.
177
- # Otherwise, it will be a configure-less set (used in fusing):
178
-
179
- # May suit for operations like: Dropout, Reshape, etc.
180
- default_qco = tp.get_default_quantization_config_options()
181
- schema.OperatorsSet("NoQuantization",
182
- default_qco.clone_and_edit(enable_activation_quantization=False)
183
- .clone_and_edit_weight_attribute(enable_weights_quantization=False))
184
-
185
- # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
186
- mixed_precision_configuration_options = schema.QuantizationConfigOptions(mixed_precision_cfg_list,
187
- base_config=base_config)
188
-
189
- # Define operator sets that use mixed_precision_configuration_options:
190
- conv = schema.OperatorsSet("Conv", mixed_precision_configuration_options)
191
- fc = schema.OperatorsSet("FullyConnected", mixed_precision_configuration_options)
192
-
193
- # Define operations sets without quantization configuration
194
- # options (useful for creating fusing patterns, for example):
195
- any_relu = schema.OperatorsSet("AnyReLU")
196
- add = schema.OperatorsSet("Add")
197
- sub = schema.OperatorsSet("Sub")
198
- mul = schema.OperatorsSet("Mul")
199
- div = schema.OperatorsSet("Div")
200
- prelu = schema.OperatorsSet("PReLU")
201
- swish = schema.OperatorsSet("Swish")
202
- sigmoid = schema.OperatorsSet("Sigmoid")
203
- tanh = schema.OperatorsSet("Tanh")
204
-
205
- # Combine multiple operators into a single operator to avoid quantization between
206
- # them. To do this we define fusing patterns using the OperatorsSets that were created.
207
- # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
208
- activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh])
209
- activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid])
210
- any_binary = schema.OperatorSetConcat([add, sub, mul, div])
211
-
212
- # ------------------- #
213
- # Fusions
214
- # ------------------- #
215
- schema.Fusing([conv, activations_after_conv_to_fuse])
216
- schema.Fusing([fc, activations_after_fc_to_fuse])
217
- schema.Fusing([any_binary, any_relu])
218
-
219
218
  return generated_tpc