mct-nightly 2.2.0.20241222.533__py3-none-any.whl → 2.2.0.20241224.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/RECORD +29 -28
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_graph.py +1 -1
  5. model_compression_toolkit/core/common/graph/base_node.py +3 -3
  6. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +4 -4
  7. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
  8. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -0
  9. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +4 -5
  10. model_compression_toolkit/target_platform_capabilities/schema/v1.py +66 -172
  11. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -1
  12. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +56 -0
  13. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +107 -0
  14. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +91 -0
  15. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -1
  16. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +7 -4
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +50 -51
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +54 -52
  19. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +57 -53
  20. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +52 -51
  21. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +53 -51
  22. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +59 -57
  23. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +54 -52
  24. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +90 -83
  25. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +26 -24
  26. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +57 -55
  27. model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +0 -67
  28. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +0 -30
  29. {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/LICENSE.md +0 -0
  30. {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/WHEEL +0 -0
  31. {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/top_level.txt +0 -0
@@ -21,13 +21,11 @@ from mct_quantizers import QuantizationMethod
21
21
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
22
22
  from model_compression_toolkit.logger import Logger
23
23
  from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
24
- from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
25
- _current_tp_model
26
24
 
27
25
  class OperatorSetNames(Enum):
28
26
  OPSET_CONV = "Conv"
29
27
  OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
30
- OPSET_CONV_TRANSPOSE = "ConvTraspose"
28
+ OPSET_CONV_TRANSPOSE = "ConvTranspose"
31
29
  OPSET_FULLY_CONNECTED = "FullyConnected"
32
30
  OPSET_CONCATENATE = "Concatenate"
33
31
  OPSET_STACK = "Stack"
@@ -43,7 +41,8 @@ class OperatorSetNames(Enum):
43
41
  OPSET_SUB = "Sub"
44
42
  OPSET_MUL = "Mul"
45
43
  OPSET_DIV = "Div"
46
- OPSET_MIN_MAX = "MinMax"
44
+ OPSET_MIN = "Min"
45
+ OPSET_MAX = "Max"
47
46
  OPSET_PRELU = "PReLU"
48
47
  OPSET_SWISH = "Swish"
49
48
  OPSET_SIGMOID = "Sigmoid"
@@ -61,7 +60,6 @@ class OperatorSetNames(Enum):
61
60
  OPSET_DROPOUT = "Dropout"
62
61
  OPSET_SPLIT = "Split"
63
62
  OPSET_CHUNK = "Chunk"
64
- OPSET_UNBIND = "Unbind"
65
63
  OPSET_MAXPOOL = "MaxPool"
66
64
  OPSET_SIZE = "Size"
67
65
  OPSET_SHAPE = "Shape"
@@ -74,6 +72,7 @@ class OperatorSetNames(Enum):
74
72
  OPSET_ZERO_PADDING2d = "ZeroPadding2D"
75
73
  OPSET_CAST = "Cast"
76
74
  OPSET_STRIDED_SLICE = "StridedSlice"
75
+ OPSET_SSD_POST_PROCESS = "SSDPostProcess"
77
76
 
78
77
  @classmethod
79
78
  def get_values(cls):
@@ -225,10 +224,10 @@ class QuantizationConfigOptions:
225
224
  QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator.
226
225
 
227
226
  Attributes:
228
- quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather.
227
+ quantization_configurations (Tuple[OpQuantizationConfig]): Tuple of possible OpQuantizationConfig to gather.
229
228
  base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
230
229
  """
231
- quantization_config_list: List[OpQuantizationConfig]
230
+ quantization_configurations: Tuple[OpQuantizationConfig]
232
231
  base_config: Optional[OpQuantizationConfig] = None
233
232
 
234
233
  def __post_init__(self):
@@ -236,32 +235,32 @@ class QuantizationConfigOptions:
236
235
  Post-initialization processing for input validation.
237
236
 
238
237
  Raises:
239
- Logger critical if quantization_config_list is not a list, contains invalid elements, or if base_config is not set correctly.
238
+ Logger critical if quantization_configurations is not a tuple, contains invalid elements, or if base_config is not set correctly.
240
239
  """
241
- # Validate `quantization_config_list`
242
- if not isinstance(self.quantization_config_list, list):
240
+ # Validate `quantization_configurations`
241
+ if not isinstance(self.quantization_configurations, tuple):
243
242
  Logger.critical(
244
- f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") # pragma: no cover
245
- for cfg in self.quantization_config_list:
243
+ f"'quantization_configurations' must be a tuple, but received: {type(self.quantization_configurations)}.") # pragma: no cover
244
+ for cfg in self.quantization_configurations:
246
245
  if not isinstance(cfg, OpQuantizationConfig):
247
246
  Logger.critical(
248
247
  f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") # pragma: no cover
249
248
 
250
249
  # Handle base_config
251
- if len(self.quantization_config_list) > 1:
250
+ if len(self.quantization_configurations) > 1:
252
251
  if self.base_config is None:
253
252
  Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") # pragma: no cover
254
- if not any(self.base_config == cfg for cfg in self.quantization_config_list):
255
- Logger.critical(f"'base_config' must be included in the quantization config options list.") # pragma: no cover
256
- elif len(self.quantization_config_list) == 1:
253
+ if not any(self.base_config == cfg for cfg in self.quantization_configurations):
254
+ Logger.critical(f"'base_config' must be included in the quantization config options.") # pragma: no cover
255
+ elif len(self.quantization_configurations) == 1:
257
256
  if self.base_config is None:
258
- object.__setattr__(self, 'base_config', self.quantization_config_list[0])
259
- elif self.base_config != self.quantization_config_list[0]:
257
+ object.__setattr__(self, 'base_config', self.quantization_configurations[0])
258
+ elif self.base_config != self.quantization_configurations[0]:
260
259
  Logger.critical(
261
- "'base_config' should be the same as the sole item in 'quantization_config_list'.") # pragma: no cover
260
+ "'base_config' should be the same as the sole item in 'quantization_configurations'.") # pragma: no cover
262
261
 
263
- elif len(self.quantization_config_list) == 0:
264
- Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") # pragma: no cover
262
+ elif len(self.quantization_configurations) == 0:
263
+ Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided configurations is empty.") # pragma: no cover
265
264
 
266
265
  def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
267
266
  """
@@ -274,10 +273,10 @@ class QuantizationConfigOptions:
274
273
  A new instance of QuantizationConfigOptions with updated configurations.
275
274
  """
276
275
  updated_base_config = replace(self.base_config, **kwargs)
277
- updated_configs_list = [
278
- replace(cfg, **kwargs) for cfg in self.quantization_config_list
276
+ updated_configs = [
277
+ replace(cfg, **kwargs) for cfg in self.quantization_configurations
279
278
  ]
280
- return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs_list)
279
+ return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
281
280
 
282
281
  def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions':
283
282
  """
@@ -292,7 +291,7 @@ class QuantizationConfigOptions:
292
291
  """
293
292
  updated_base_config = self.base_config
294
293
  updated_configs = []
295
- for qc in self.quantization_config_list:
294
+ for qc in self.quantization_configurations:
296
295
  if attrs is None:
297
296
  attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
298
297
  else:
@@ -300,7 +299,7 @@ class QuantizationConfigOptions:
300
299
  # Ensure all attributes exist in the config
301
300
  for attr in attrs_to_update:
302
301
  if attr not in qc.attr_weights_configs_mapping:
303
- Logger.critical(f"{attr} does not exist in {qc}.")
302
+ Logger.critical(f"{attr} does not exist in {qc}.") # pragma: no cover
304
303
  updated_attr_mapping = {
305
304
  attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
306
305
  for attr in attrs_to_update
@@ -308,7 +307,7 @@ class QuantizationConfigOptions:
308
307
  if qc == updated_base_config:
309
308
  updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping)
310
309
  updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
311
- return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs)
310
+ return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
312
311
 
313
312
  def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
314
313
  """
@@ -322,7 +321,7 @@ class QuantizationConfigOptions:
322
321
  """
323
322
  updated_configs = []
324
323
  new_base_config = self.base_config
325
- for qc in self.quantization_config_list:
324
+ for qc in self.quantization_configurations:
326
325
  if layer_attrs_mapping is None:
327
326
  new_attr_mapping = {}
328
327
  else:
@@ -333,7 +332,7 @@ class QuantizationConfigOptions:
333
332
  if qc == self.base_config:
334
333
  new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping)
335
334
  updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping))
336
- return replace(self, base_config=new_base_config, quantization_config_list=updated_configs)
335
+ return replace(self, base_config=new_base_config, quantization_configurations=tuple(updated_configs))
337
336
 
338
337
  def get_info(self) -> Dict[str, Any]:
339
338
  """
@@ -342,7 +341,7 @@ class QuantizationConfigOptions:
342
341
  Returns:
343
342
  dict: Information about the quantization configuration options as a dictionary.
344
343
  """
345
- return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)}
344
+ return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
346
345
 
347
346
 
348
347
  @dataclass(frozen=True)
@@ -350,22 +349,7 @@ class TargetPlatformModelComponent:
350
349
  """
351
350
  Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
352
351
  """
353
-
354
- def __post_init__(self):
355
- """
356
- Post-initialization to register the component with the current TargetPlatformModel.
357
- """
358
- _current_tp_model.get().append_component(self)
359
-
360
- def get_info(self) -> Dict[str, Any]:
361
- """
362
- Get information about the component to display.
363
-
364
- Returns:
365
- Dict[str, Any]: Returns an empty dictionary. The actual component should override
366
- this method to provide relevant information.
367
- """
368
- return {}
352
+ pass
369
353
 
370
354
 
371
355
  @dataclass(frozen=True)
@@ -374,12 +358,7 @@ class OperatorsSetBase(TargetPlatformModelComponent):
374
358
  Base class to represent a set of a target platform model component of operator set types.
375
359
  Inherits from TargetPlatformModelComponent.
376
360
  """
377
- def __post_init__(self):
378
- """
379
- Post-initialization to ensure the component is registered with the TargetPlatformModel.
380
- Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel.
381
- """
382
- super().__post_init__()
361
+ pass
383
362
 
384
363
 
385
364
  @dataclass(frozen=True)
@@ -394,23 +373,9 @@ class OperatorsSet(OperatorsSetBase):
394
373
  is_default (bool): Indicates whether this set is the default quantization configuration
395
374
  for the TargetPlatformModel or a fusing set.
396
375
  """
397
- name: str
376
+ name: Union[str, OperatorSetNames]
398
377
  qc_options: QuantizationConfigOptions = None
399
378
 
400
- def __post_init__(self):
401
- """
402
- Post-initialization processing to mark the operator set as default if applicable.
403
-
404
- Calls the parent class's __post_init__ method and sets `is_default` to True
405
- if this set corresponds to the default quantization configuration for the
406
- TargetPlatformModel or if it is a fusing set.
407
-
408
- """
409
- super().__post_init__()
410
- is_fusing_set = self.qc_options is None
411
- is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
412
- object.__setattr__(self, 'is_default', is_default)
413
-
414
379
  def get_info(self) -> Dict[str, Any]:
415
380
  """
416
381
  Get information about the set as a dictionary.
@@ -419,83 +384,67 @@ class OperatorsSet(OperatorsSetBase):
419
384
  Dict[str, Any]: A dictionary containing the set name and
420
385
  whether it is the default quantization configuration.
421
386
  """
422
- return {"name": self.name,
423
- "is_default_qc": self.is_default}
387
+ return {"name": self.name}
424
388
 
425
389
 
426
390
  @dataclass(frozen=True)
427
391
  class OperatorSetConcat(OperatorsSetBase):
428
392
  """
429
- Concatenate a list of operator sets to treat them similarly in different places (like fusing).
393
+ Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
430
394
 
431
395
  Attributes:
432
- op_set_list (List[OperatorsSet]): List of operator sets to group.
396
+ operators_set (Tuple[OperatorsSet]): Tuple of operator sets to group.
433
397
  qc_options (None): Configuration options for the set, always None for concatenated sets.
434
- name (str): Concatenated name generated from the names of the operator sets in the list.
398
+ name (str): Concatenated name generated from the names of the operator sets.
435
399
  """
436
- op_set_list: List[OperatorsSet] = field(default_factory=list)
400
+ operators_set: Tuple[OperatorsSet]
437
401
  qc_options: None = field(default=None, init=False)
438
- name: str = None
439
402
 
440
403
  def __post_init__(self):
441
404
  """
442
405
  Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
443
406
 
444
407
  Calls the parent class's __post_init__ method and creates a concatenated name
445
- by joining the names of all operator sets in `op_set_list`.
408
+ by joining the names of all operator sets in `operators_set`.
446
409
  """
447
- super().__post_init__()
448
410
  # Generate the concatenated name from the operator sets
449
- concatenated_name = "_".join([op.name for op in self.op_set_list])
411
+ concatenated_name = "_".join([op.name.value if hasattr(op.name, "value") else op.name for op in self.operators_set])
450
412
  # Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
451
413
  object.__setattr__(self, "name", concatenated_name)
452
414
 
453
- def get_info(self) -> Dict[str, Any]:
454
- """
455
- Get information about the concatenated set as a dictionary.
456
-
457
- Returns:
458
- Dict[str, Any]: A dictionary containing the concatenated name and
459
- the list of names of the operator sets in `op_set_list`.
460
- """
461
- return {"name": self.name,
462
- OPS_SET_LIST: [s.name for s in self.op_set_list]}
463
-
464
415
 
465
416
  @dataclass(frozen=True)
466
417
  class Fusing(TargetPlatformModelComponent):
467
418
  """
468
- Fusing defines a list of operators that should be combined and treated as a single operator,
419
+ Fusing defines a tuple of operators that should be combined and treated as a single operator,
469
420
  hence no quantization is applied between them.
470
421
 
471
422
  Attributes:
472
- operator_groups_list (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups,
423
+ operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A tuple of operator groups,
473
424
  each being either an OperatorSetConcat or an OperatorsSet.
474
425
  name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
475
426
  """
476
- operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]]
477
- name: str = None
427
+ operator_groups: Tuple[Union[OperatorsSet, OperatorSetConcat]]
478
428
 
479
429
  def __post_init__(self):
480
430
  """
481
431
  Post-initialization processing for input validation and name generation.
482
432
 
483
- Calls the parent class's __post_init__ method, validates the operator_groups_list,
433
+ Calls the parent class's __post_init__ method, validates the operator_groups,
484
434
  and generates the name if not explicitly provided.
485
435
 
486
436
  Raises:
487
- Logger critical if operator_groups_list is not a list or if it contains fewer than two operators.
437
+ Logger critical if operator_groups is not a tuple or if it contains fewer than two operators.
488
438
  """
489
- super().__post_init__()
490
- # Validate the operator_groups_list
491
- if not isinstance(self.operator_groups_list, list):
439
+ # Validate the operator_groups
440
+ if not isinstance(self.operator_groups, tuple):
492
441
  Logger.critical(
493
- f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") # pragma: no cover
494
- if len(self.operator_groups_list) < 2:
442
+ f"Operator groups should be of type 'tuple' but is {type(self.operator_groups)}.") # pragma: no cover
443
+ if len(self.operator_groups) < 2:
495
444
  Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
496
445
 
497
446
  # Generate the name from the operator groups if not provided
498
- generated_name = '_'.join([x.name for x in self.operator_groups_list])
447
+ generated_name = '_'.join([x.name.value if hasattr(x.name, 'value') else x.name for x in self.operator_groups])
499
448
  object.__setattr__(self, 'name', generated_name)
500
449
 
501
450
  def contains(self, other: Any) -> bool:
@@ -512,11 +461,11 @@ class Fusing(TargetPlatformModelComponent):
512
461
  return False
513
462
 
514
463
  # Check for containment by comparing operator groups
515
- for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
516
- for j in range(len(other.operator_groups_list)):
517
- if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
518
- isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
519
- other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
464
+ for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
465
+ for j in range(len(other.operator_groups)):
466
+ if self.operator_groups[i + j] != other.operator_groups[j] and not (
467
+ isinstance(self.operator_groups[i + j], OperatorSetConcat) and (
468
+ other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
520
469
  break
521
470
  else:
522
471
  # If all checks pass, the other Fusing instance is contained
@@ -534,8 +483,8 @@ class Fusing(TargetPlatformModelComponent):
534
483
  or just the sequence of operator groups if no name is set.
535
484
  """
536
485
  if self.name is not None:
537
- return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
538
- return ' -> '.join([x.name for x in self.operator_groups_list])
486
+ return {self.name: ' -> '.join([x.name for x in self.operator_groups])}
487
+ return ' -> '.join([x.name for x in self.operator_groups])
539
488
 
540
489
 
541
490
  @dataclass(frozen=True)
@@ -550,8 +499,8 @@ class TargetPlatformModel:
550
499
  tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
551
500
  add_metadata (bool): Flag to determine if metadata should be added.
552
501
  name (str): Name of the Target Platform Model.
553
- operator_set (List[OperatorsSetBase]): List of operator sets within the model.
554
- fusing_patterns (List[Fusing]): List of fusing patterns for the model.
502
+ operator_set (Tuple[OperatorsSetBase]): Tuple of operator sets within the model.
503
+ fusing_patterns (Tuple[Fusing]): Tuple of fusing patterns for the model.
555
504
  is_simd_padding (bool): Indicates if SIMD padding is applied.
556
505
  SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
557
506
  """
@@ -561,8 +510,8 @@ class TargetPlatformModel:
561
510
  tpc_platform_type: Optional[str]
562
511
  add_metadata: bool = True
563
512
  name: str = "default_tp_model"
564
- operator_set: List[OperatorsSetBase] = field(default_factory=list)
565
- fusing_patterns: List[Fusing] = field(default_factory=list)
513
+ operator_set: Tuple[OperatorsSetBase] = None
514
+ fusing_patterns: Tuple[Fusing] = None
566
515
  is_simd_padding: bool = False
567
516
 
568
517
  SCHEMA_VERSION: int = 1
@@ -578,26 +527,12 @@ class TargetPlatformModel:
578
527
  # Validate `default_qco`
579
528
  if not isinstance(self.default_qco, QuantizationConfigOptions):
580
529
  Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover
581
- if len(self.default_qco.quantization_config_list) != 1:
530
+ if len(self.default_qco.quantization_configurations) != 1:
582
531
  Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
583
532
 
584
- def append_component(self, tp_model_component: TargetPlatformModelComponent):
585
- """
586
- Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet).
587
-
588
- Args:
589
- tp_model_component (TargetPlatformModelComponent): Component to attach to the model.
590
-
591
- Raises:
592
- Logger critical if the component is not an instance of Fusing or OperatorsSetBase.
593
- """
594
- if isinstance(tp_model_component, Fusing):
595
- self.fusing_patterns.append(tp_model_component)
596
- elif isinstance(tp_model_component, OperatorsSetBase):
597
- self.operator_set.append(tp_model_component)
598
- else:
599
- Logger.critical(
600
- f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") # pragma: no cover
533
+ opsets_names = [op.name.value if hasattr(op.name, "value") else op.name for op in self.operator_set] if self.operator_set else []
534
+ if len(set(opsets_names)) != len(opsets_names):
535
+ Logger.critical("Operator Sets must have unique names.") # pragma: no cover
601
536
 
602
537
  def get_info(self) -> Dict[str, Any]:
603
538
  """
@@ -608,51 +543,10 @@ class TargetPlatformModel:
608
543
  """
609
544
  return {
610
545
  "Model name": self.name,
611
- "Operators sets": [o.get_info() for o in self.operator_set],
612
- "Fusing patterns": [f.get_info() for f in self.fusing_patterns],
546
+ "Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
547
+ "Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
613
548
  }
614
549
 
615
- def __validate_model(self):
616
- """
617
- Validate the model's configuration to ensure its integrity.
618
-
619
- Raises:
620
- Logger critical if the model contains multiple operator sets with the same name.
621
- """
622
- opsets_names = [op.name for op in self.operator_set]
623
- if len(set(opsets_names)) != len(opsets_names):
624
- Logger.critical("Operator Sets must have unique names.") # pragma: no cover
625
-
626
- def __enter__(self) -> 'TargetPlatformModel':
627
- """
628
- Start defining the TargetPlatformModel using a 'with' statement.
629
-
630
- Returns:
631
- TargetPlatformModel: The initialized TargetPlatformModel object.
632
- """
633
- _current_tp_model.set(self)
634
- return self
635
-
636
- def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel':
637
- """
638
- Finalize and validate the TargetPlatformModel at the end of the 'with' clause.
639
-
640
- Args:
641
- exc_type: Exception type, if any occurred.
642
- exc_value: Exception value, if any occurred.
643
- tb: Traceback object, if an exception occurred.
644
-
645
- Raises:
646
- The exception raised in the 'with' block, if any.
647
-
648
- Returns:
649
- TargetPlatformModel: The validated TargetPlatformModel object.
650
- """
651
- if exc_value is not None:
652
- raise exc_value
653
- self.__validate_model()
654
- _current_tp_model.reset()
655
- return self
656
550
 
657
551
  def show(self):
658
552
  """
@@ -15,7 +15,6 @@
15
15
 
16
16
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
17
17
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
18
- from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options
19
18
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSet, \
20
19
  OperatorSetConcat, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
21
20
 
@@ -0,0 +1,56 @@
1
+ from typing import Dict, Tuple, List, Any, Optional
2
+
3
+ from model_compression_toolkit import DefaultDict
4
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
5
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
6
+ OperationsSetToLayers
7
+
8
+
9
+ class AttachTpModelToFw:
10
+
11
+ def __init__(self):
12
+ self._opset2layer = None
13
+
14
+ # A mapping that associates each layer type in the operation set (with weight attributes and a quantization
15
+ # configuration in the target platform model) to its framework-specific attribute name. If not all layer types
16
+ # in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
17
+ self._opset2attr_mapping = None # Mapping of operation sets to their corresponding framework-specific layers
18
+
19
+ def attach(self, tpc_model: TargetPlatformModel,
20
+ custom_opset2layer: Dict[str, Tuple[List[Any], Optional[Dict[str, DefaultDict]]]] = None
21
+ ) -> TargetPlatformCapabilities:
22
+ """
23
+ Attaching a TargetPlatformModel which includes a platform capabilities description to specific
24
+ framework's operators.
25
+
26
+ Args:
27
+ tpc_model: a TargetPlatformModel object.
28
+ custom_opset2layer: optional set of custom operator sets which allows to add/override the built-in set
29
+ of framework operator, to define a specific behavior for those operators. This dictionary should map
30
+ an operator set unique name to a pair of: a list of framework operators and an optional
31
+ operator's attributes names mapping.
32
+
33
+ Returns: a TargetPlatformCapabilities object.
34
+
35
+ """
36
+
37
+ tpc = TargetPlatformCapabilities(tpc_model)
38
+
39
+ with tpc:
40
+ for opset_name, operators in self._opset2layer.items():
41
+ attr_mapping = self._opset2attr_mapping.get(opset_name)
42
+ OperationsSetToLayers(opset_name, operators, attr_mapping=attr_mapping)
43
+
44
+ if custom_opset2layer is not None:
45
+ for opset_name, operators in custom_opset2layer.items():
46
+ if len(operators) == 1:
47
+ OperationsSetToLayers(opset_name, operators[0])
48
+ elif len(operators) == 2:
49
+ OperationsSetToLayers(opset_name, operators[0], attr_mapping=operators[1])
50
+ else:
51
+ raise ValueError(f"Custom operator set to layer mapping should include up to 2 elements - "
52
+ f"a list of layers to attach to the operator and an optional mapping of "
53
+ f"attributes names, but given a mapping contains {len(operators)} elements.")
54
+
55
+ return tpc
56
+
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import tensorflow as tf
17
+ from packaging import version
18
+
19
+ from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
20
+
21
+ if FOUND_SONY_CUSTOM_LAYERS:
22
+ from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
23
+
24
+ if version.parse(tf.__version__) >= version.parse("2.13"):
25
+ from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
26
+ MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
27
+ Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
28
+ else:
29
+ from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
30
+ MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
31
+ Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
32
+
33
+ from model_compression_toolkit import DefaultDict
34
+ from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
35
+ BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
36
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
37
+ from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
38
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
39
+ AttachTpModelToFw
40
+
41
+
42
class AttachTpModelToKeras(AttachTpModelToFw):
    """
    Attaches a framework-independent TargetPlatformModel to Keras/TensorFlow operators.

    Populates the base-class mappings:
      - ``_opset2layer``: operator-set name (string) -> list of Keras layers / TF ops
        (possibly filtered via LayerFilterParams) implementing that operator set.
      - ``_opset2attr_mapping``: operator-set name -> quantization-attribute name mapping
        (kernel/bias) for the weighted operators (conv, depthwise conv, fully connected).
    """

    def __init__(self):
        super().__init__()

        # Operator-set name -> Keras/TF operators implementing it. Keys are the string
        # values of OperatorSetNames so they match lookups by operator-set name.
        self._opset2layer = {
            OperatorSetNames.OPSET_CONV.value: [Conv2D, tf.nn.conv2d],
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
            OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [Conv2DTranspose, tf.nn.conv2d_transpose],
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Dense],
            OperatorSetNames.OPSET_CONCATENATE.value: [tf.concat, Concatenate],
            OperatorSetNames.OPSET_STACK.value: [tf.stack],
            OperatorSetNames.OPSET_UNSTACK.value: [tf.unstack],
            OperatorSetNames.OPSET_GATHER.value: [tf.gather, tf.compat.v1.gather],
            OperatorSetNames.OPSET_EXPAND.value: [],  # no Keras/TF equivalent registered
            OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNormalization],
            OperatorSetNames.OPSET_RELU.value: [tf.nn.relu, ReLU],
            OperatorSetNames.OPSET_RELU6.value: [tf.nn.relu6],
            OperatorSetNames.OPSET_LEAKY_RELU.value: [tf.nn.leaky_relu, LeakyReLU],
            OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Activation, activation="hard_tanh")],
            OperatorSetNames.OPSET_ADD.value: [tf.add, Add],
            OperatorSetNames.OPSET_SUB.value: [tf.subtract, Subtract],
            OperatorSetNames.OPSET_MUL.value: [tf.math.multiply, Multiply],
            OperatorSetNames.OPSET_DIV.value: [tf.math.divide, tf.math.truediv],
            OperatorSetNames.OPSET_MIN.value: [tf.math.minimum, Minimum],
            OperatorSetNames.OPSET_MAX.value: [tf.math.maximum, Maximum],
            OperatorSetNames.OPSET_PRELU.value: [PReLU],
            OperatorSetNames.OPSET_SWISH.value: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
            OperatorSetNames.OPSET_SIGMOID.value: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
            OperatorSetNames.OPSET_TANH.value: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
            OperatorSetNames.OPSET_GELU.value: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
            OperatorSetNames.OPSET_HARDSIGMOID.value: [tf.keras.activations.hard_sigmoid,
                                                       LayerFilterParams(Activation, activation="hard_sigmoid")],
            OperatorSetNames.OPSET_FLATTEN.value: [Flatten],
            OperatorSetNames.OPSET_GET_ITEM.value: [tf.__operators__.getitem],
            OperatorSetNames.OPSET_RESHAPE.value: [Reshape, tf.reshape],
            OperatorSetNames.OPSET_PERMUTE.value: [Permute],
            OperatorSetNames.OPSET_TRANSPOSE.value: [tf.transpose],
            OperatorSetNames.OPSET_DROPOUT.value: [Dropout],
            OperatorSetNames.OPSET_SPLIT.value: [tf.split],
            OperatorSetNames.OPSET_MAXPOOL.value: [MaxPooling2D],
            OperatorSetNames.OPSET_SHAPE.value: [tf.shape, tf.compat.v1.shape],
            OperatorSetNames.OPSET_EQUAL.value: [tf.math.equal],
            OperatorSetNames.OPSET_ARGMAX.value: [tf.math.argmax],
            OperatorSetNames.OPSET_TOPK.value: [tf.nn.top_k],
            OperatorSetNames.OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS.value: [tf.quantization.fake_quant_with_min_max_vars],
            OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value: [tf.image.combined_non_max_suppression],
            OperatorSetNames.OPSET_CROPPING2D.value: [Cropping2D],
            OperatorSetNames.OPSET_ZERO_PADDING2d.value: [ZeroPadding2D],
            OperatorSetNames.OPSET_CAST.value: [tf.cast],
            OperatorSetNames.OPSET_STRIDED_SLICE.value: [tf.strided_slice]
        }

        if FOUND_SONY_CUSTOM_LAYERS:
            # Fix: key by the enum's string value, consistent with every other entry
            # above, so lookups by operator-set name find the post-process mapping.
            self._opset2layer[OperatorSetNames.OPSET_POST_PROCESS.value] = [SSDPostProcess]

        # Weighted operator sets -> framework-specific names of their quantized
        # attributes (kernel/bias). DepthwiseConv2D uses the depthwise kernel name;
        # any other conv-like op falls back to the regular Keras kernel name.
        self._opset2attr_mapping = {
            OperatorSetNames.OPSET_CONV.value: {
                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: {
                KERNEL_ATTR: DefaultDict({
                    DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
                    tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: {
                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)}}