mct-nightly 2.2.0.20241201.617__py3-none-any.whl → 2.2.0.20241202.131715__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/RECORD +58 -58
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +0 -3
  5. model_compression_toolkit/core/common/graph/base_node.py +7 -5
  6. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  7. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -2
  8. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -2
  9. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +2 -1
  11. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +1 -1
  12. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -1
  13. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  14. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  15. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
  16. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -1
  17. model_compression_toolkit/metadata.py +14 -5
  18. model_compression_toolkit/target_platform_capabilities/schema/__init__.py +14 -0
  19. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +11 -0
  20. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +37 -0
  21. model_compression_toolkit/target_platform_capabilities/{target_platform/op_quantization_config.py → schema/v1.py} +377 -24
  22. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +3 -5
  23. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -214
  24. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -2
  25. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +6 -10
  26. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +39 -32
  27. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +3 -2
  28. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +3 -5
  29. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +36 -31
  30. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +3 -2
  31. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +3 -4
  32. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +37 -32
  33. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +3 -2
  34. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +3 -4
  35. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +39 -32
  36. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +3 -2
  37. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +3 -4
  38. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +36 -31
  39. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +3 -2
  40. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +3 -4
  41. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +45 -38
  42. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +3 -2
  43. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -4
  44. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +37 -32
  45. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +3 -2
  46. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -4
  47. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +70 -62
  48. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +3 -2
  49. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +3 -4
  50. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +22 -17
  51. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +3 -4
  52. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +3 -4
  53. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +56 -51
  54. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -4
  55. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -4
  56. model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +0 -85
  57. model_compression_toolkit/target_platform_capabilities/target_platform/operators.py +0 -87
  58. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model_component.py +0 -40
  59. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/LICENSE.md +0 -0
  60. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/WHEEL +0 -0
  61. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,14 +12,23 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
15
  import copy
17
- from typing import List, Dict, Union, Any, Tuple
16
+
18
17
  from enum import Enum
19
18
 
19
+ import pprint
20
+
21
+ from typing import Dict, Any, Union, Tuple, List, Optional
22
+
20
23
  from mct_quantizers import QuantizationMethod
21
24
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
25
+
22
26
  from model_compression_toolkit.logger import Logger
27
+ from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
28
+ from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
29
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
30
+ get_current_tp_model, _current_tp_model
31
+ from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import clone_and_edit_object_params
23
32
 
24
33
 
25
34
  class Signedness(Enum):
@@ -35,27 +44,6 @@ class Signedness(Enum):
35
44
  UNSIGNED = 2
36
45
 
37
46
 
38
- def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
39
- """
40
- Clones the given object and edit some of its parameters.
41
-
42
- Args:
43
- obj: An object to clone.
44
- **kwargs: Keyword arguments to edit in the cloned object.
45
-
46
- Returns:
47
- Edited copy of the given object.
48
- """
49
-
50
- obj_copy = copy.deepcopy(obj)
51
- for k, v in kwargs.items():
52
- assert hasattr(obj_copy,
53
- k), f'Edit parameter is possible only for existing parameters in the given object, ' \
54
- f'but {k} is not a parameter of {obj_copy}.'
55
- setattr(obj_copy, k, v)
56
- return obj_copy
57
-
58
-
59
47
  class AttributeQuantizationConfig:
60
48
  """
61
49
  Hold the quantization configuration of a weight attribute of a layer.
@@ -376,3 +364,368 @@ class QuantizationConfigOptions:
376
364
  def get_info(self):
377
365
  return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)}
378
366
 
367
+
368
+ class TargetPlatformModelComponent:
369
+ """
370
+ Component of TargetPlatformModel (Fusing, OperatorsSet, etc.)
371
+ """
372
+ def __init__(self, name: str):
373
+ """
374
+
375
+ Args:
376
+ name: Name of component.
377
+ """
378
+ self.name = name
379
+ _current_tp_model.get().append_component(self)
380
+
381
+ def get_info(self) -> Dict[str, Any]:
382
+ """
383
+
384
+ Returns: Information about the component to display (an empty dictionary here;
385
+ the actual component should fill it with info).
386
+
387
+ """
388
+ return {}
389
+
390
+
391
+ class OperatorsSetBase(TargetPlatformModelComponent):
392
+ """
393
+ Base class to represent a set of operators.
394
+ """
395
+ def __init__(self, name: str):
396
+ """
397
+
398
+ Args:
399
+ name: Name of OperatorsSet.
400
+ """
401
+ super().__init__(name=name)
402
+
403
+
404
+ class OperatorsSet(OperatorsSetBase):
405
+ def __init__(self,
406
+ name: str,
407
+ qc_options: QuantizationConfigOptions = None):
408
+ """
409
+ Set of operators that are represented by a unique label.
410
+
411
+ Args:
412
+ name (str): Set's label (must be unique in a TargetPlatformModel).
413
+ qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
414
+ """
415
+
416
+ super().__init__(name)
417
+ self.qc_options = qc_options
418
+ is_fusing_set = qc_options is None
419
+ self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
420
+
421
+
422
+ def get_info(self) -> Dict[str,Any]:
423
+ """
424
+
425
+ Returns: Info about the set as a dictionary.
426
+
427
+ """
428
+ return {"name": self.name,
429
+ "is_default_qc": self.is_default}
430
+
431
+
432
+ class OperatorSetConcat(OperatorsSetBase):
433
+ """
434
+ Concatenate a list of operator sets to treat them similarly in different places (like fusing).
435
+ """
436
+ def __init__(self, *opsets: OperatorsSet):
437
+ """
438
+ Group a list of operation sets.
439
+
440
+ Args:
441
+ *opsets (OperatorsSet): List of operator sets to group.
442
+ """
443
+ name = "_".join([a.name for a in opsets])
444
+ super().__init__(name=name)
445
+ self.op_set_list = opsets
446
+ self.qc_options = None # Concat have no qc options
447
+
448
+ def get_info(self) -> Dict[str,Any]:
449
+ """
450
+
451
+ Returns: Info about the sets group as a dictionary.
452
+
453
+ """
454
+ return {"name": self.name,
455
+ OPS_SET_LIST: [s.name for s in self.op_set_list]}
456
+
457
+
458
+ class Fusing(TargetPlatformModelComponent):
459
+ """
460
+ Fusing defines a list of operators that should be combined and treated as a single operator,
461
+ hence no quantization is applied between them.
462
+ """
463
+
464
+ def __init__(self,
465
+ operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
466
+ name: str = None):
467
+ """
468
+ Args:
469
+ operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
470
+ name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
471
+ """
472
+ assert isinstance(operator_groups_list,
473
+ list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
474
+ assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
475
+
476
+ # Generate a name from the operator groups if no name is provided
477
+ if name is None:
478
+ name = '_'.join([x.name for x in operator_groups_list])
479
+
480
+ super().__init__(name)
481
+ self.operator_groups_list = operator_groups_list
482
+
483
+ def contains(self, other: Any) -> bool:
484
+ """
485
+ Determines if the current Fusing instance contains another Fusing instance.
486
+
487
+ Args:
488
+ other: The other Fusing instance to check against.
489
+
490
+ Returns:
491
+ A boolean indicating whether the other instance is contained within this one.
492
+ """
493
+ if not isinstance(other, Fusing):
494
+ return False
495
+
496
+ # Check for containment by comparing operator groups
497
+ for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
498
+ for j in range(len(other.operator_groups_list)):
499
+ if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
500
+ isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
501
+ other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
502
+ break
503
+ else:
504
+ # If all checks pass, the other Fusing instance is contained
505
+ return True
506
+ # Other Fusing instance is not contained
507
+ return False
508
+
509
+ def get_info(self):
510
+ """
511
+ Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
512
+
513
+ Returns:
514
+ A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value,
515
+ or just the sequence of operator groups if no name is set.
516
+ """
517
+ if self.name is not None:
518
+ return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
519
+ return ' -> '.join([x.name for x in self.operator_groups_list])
520
+
521
+
522
+ class TargetPlatformModel(ImmutableClass):
523
+ """
524
+ Represents the hardware configuration used for quantized model inference.
525
+
526
+ This model defines:
527
+ - The operators and their associated quantization configurations.
528
+ - Fusing patterns, enabling multiple operators to be combined into a single operator
529
+ for optimization during inference.
530
+ - Versioning support through minor and patch versions for backward compatibility.
531
+
532
+ Attributes:
533
+ SCHEMA_VERSION (int): The schema version of the target platform model.
534
+ """
535
+ SCHEMA_VERSION = 1
536
+ def __init__(self,
537
+ default_qco: QuantizationConfigOptions,
538
+ tpc_minor_version: Optional[int],
539
+ tpc_patch_version: Optional[int],
540
+ tpc_platform_type: Optional[str],
541
+ add_metadata: bool = True,
542
+ name="default_tp_model"):
543
+ """
544
+
545
+ Args:
546
+ default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators whose QuantizationConfigOptions are not defined in the model.
547
+ tpc_minor_version (Optional[int]): The minor version of the target platform capabilities.
548
+ tpc_patch_version (Optional[int]): The patch version of the target platform capabilities.
549
+ tpc_platform_type (Optional[str]): The platform type of the target platform capabilities.
550
+ add_metadata (bool): Whether to add metadata to the model or not.
551
+ name (str): Name of the model.
552
+
553
+ Raises:
554
+ AssertionError: If the provided `default_qco` does not contain exactly one quantization configuration.
555
+ """
556
+
557
+ super().__init__()
558
+ self.tpc_minor_version = tpc_minor_version
559
+ self.tpc_patch_version = tpc_patch_version
560
+ self.tpc_platform_type = tpc_platform_type
561
+ self.add_metadata = add_metadata
562
+ self.name = name
563
+ self.operator_set = []
564
+ assert isinstance(default_qco, QuantizationConfigOptions), \
565
+ "default_qco must be an instance of QuantizationConfigOptions"
566
+ assert len(default_qco.quantization_config_list) == 1, \
567
+ "Default QuantizationConfigOptions must contain exactly one option."
568
+
569
+ self.default_qco = default_qco
570
+ self.fusing_patterns = []
571
+ self.is_simd_padding = False
572
+
573
+ def get_config_options_by_operators_set(self,
574
+ operators_set_name: str) -> QuantizationConfigOptions:
575
+ """
576
+ Get the QuantizationConfigOptions of an OperatorsSet by the OperatorsSet name.
577
+ If the name is not in the model, the default QuantizationConfigOptions is returned.
578
+
579
+ Args:
580
+ operators_set_name: Name of OperatorsSet to get.
581
+
582
+ Returns:
583
+ QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name.
584
+ """
585
+ for op_set in self.operator_set:
586
+ if operators_set_name == op_set.name:
587
+ return op_set.qc_options
588
+ return self.default_qco
589
+
590
+ def get_default_op_quantization_config(self) -> OpQuantizationConfig:
591
+ """
592
+
593
+ Returns: The default OpQuantizationConfig of the TargetPlatformModel.
594
+
595
+ """
596
+ assert len(self.default_qco.quantization_config_list) == 1, \
597
+ f'Default quantization configuration options must contain only one option,' \
598
+ f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.'
599
+ return self.default_qco.quantization_config_list[0]
600
+
601
+ def is_opset_in_model(self,
602
+ opset_name: str) -> bool:
603
+ """
604
+ Check whether an operators set is defined in the model or not.
605
+
606
+ Args:
607
+ opset_name: Operators set name to check.
608
+
609
+ Returns:
610
+ Whether an operators set is defined in the model or not.
611
+ """
612
+ return opset_name in [x.name for x in self.operator_set]
613
+
614
+ def get_opset_by_name(self,
615
+ opset_name: str) -> OperatorsSetBase:
616
+ """
617
+ Get an OperatorsSet object from the model by its name.
618
+ If name is not in the model - None is returned.
619
+
620
+ Args:
621
+ opset_name: OperatorsSet name to retrieve.
622
+
623
+ Returns:
624
+ OperatorsSet object with the name opset_name, or None if opset_name is not in the model.
625
+ """
626
+
627
+ opset_list = [x for x in self.operator_set if x.name == opset_name]
628
+ assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \
629
+ f' TargetPlatformModel with the name {opset_name}. ' \
630
+ f'OperatorsSet name must be unique.'
631
+ if len(opset_list) == 0: # opset_name is not in the model.
632
+ return None
633
+
634
+ return opset_list[0] # There's one opset with that name
635
+
636
+ def append_component(self,
637
+ tp_model_component: TargetPlatformModelComponent):
638
+ """
639
+ Attach a TargetPlatformModel component to the model. Components can be for example:
640
+ Fusing, OperatorsSet, etc.
641
+
642
+ Args:
643
+ tp_model_component: Component to attach to the model.
644
+
645
+ """
646
+ if isinstance(tp_model_component, Fusing):
647
+ self.fusing_patterns.append(tp_model_component)
648
+ elif isinstance(tp_model_component, OperatorsSetBase):
649
+ self.operator_set.append(tp_model_component)
650
+ else: # pragma: no cover
651
+ Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.')
652
+
653
+ def __enter__(self):
654
+ """
655
+ Start defining the TargetPlatformModel using 'with'.
656
+
657
+ Returns: Initialized TargetPlatformModel object.
658
+
659
+ """
660
+ _current_tp_model.set(self)
661
+ return self
662
+
663
+ def __exit__(self, exc_type, exc_value, tb):
664
+ """
665
+ Finish defining the TargetPlatformModel at the end of the 'with' clause.
666
+ Returns the final and immutable TargetPlatformModel instance.
667
+ """
668
+
669
+ if exc_value is not None:
670
+ print(exc_value, exc_value.args)
671
+ raise exc_value
672
+ self.__validate_model() # Assert that model is valid.
673
+ _current_tp_model.reset()
674
+ self.initialized_done() # Make model immutable.
675
+ return self
676
+
677
+ def __validate_model(self):
678
+ """
679
+
680
+ Assert model is valid.
681
+ Model is invalid if, for example, it contains multiple operator sets with the same name,
682
+ as their names should be unique.
683
+
684
+ """
685
+ opsets_names = [op.name for op in self.operator_set]
686
+ if len(set(opsets_names)) != len(opsets_names):
687
+ Logger.critical(f'Operator Sets must have unique names.')
688
+
689
+ def get_default_config(self) -> OpQuantizationConfig:
690
+ """
691
+
692
+ Returns:
693
+
694
+ """
695
+ assert len(self.default_qco.quantization_config_list) == 1, \
696
+ f'Default quantization configuration options must contain only one option,' \
697
+ f' but found {len(self.default_qco.quantization_config_list)} configurations.'
698
+ return self.default_qco.quantization_config_list[0]
699
+
700
+ def get_info(self) -> Dict[str, Any]:
701
+ """
702
+
703
+ Returns: Dictionary that summarizes the TargetPlatformModel properties (for display purposes).
704
+
705
+ """
706
+ return {"Model name": self.name,
707
+ "Default quantization config": self.get_default_config().get_info(),
708
+ "Operators sets": [o.get_info() for o in self.operator_set],
709
+ "Fusing patterns": [f.get_info() for f in self.fusing_patterns]
710
+ }
711
+
712
+ def show(self):
713
+ """
714
+
715
+ Display the TargetPlatformModel.
716
+
717
+ """
718
+ pprint.pprint(self.get_info(), sort_dicts=False)
719
+
720
+ def set_simd_padding(self,
721
+ is_simd_padding: bool):
722
+ """
723
+ Set flag is_simd_padding to indicate whether this TP model defines
724
+ that padding due to SIMD constraints occurs.
725
+
726
+ Args:
727
+ is_simd_padding: Whether this TP model defines that padding due to SIMD constraints occurs.
728
+
729
+ """
730
+ self.is_simd_padding = is_simd_padding
731
+
@@ -13,13 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
17
16
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
18
17
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
19
- from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options, TargetPlatformModel
20
- from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
21
- OpQuantizationConfig, QuantizationConfigOptions, AttributeQuantizationConfig, Signedness
22
- from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options
19
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSet, \
20
+ OperatorSetConcat, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
23
21
 
24
22
  from mct_quantizers import QuantizationMethod
25
23
 
@@ -13,19 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import pprint
17
- from typing import Any, Dict
18
-
19
- from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model, \
20
- get_current_tp_model
21
- from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
22
- from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import \
23
- TargetPlatformModelComponent
24
- from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
25
- QuantizationConfigOptions
26
- from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
27
- from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
28
- from model_compression_toolkit.logger import Logger
16
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import get_current_tp_model
17
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions
29
18
 
30
19
 
31
20
  def get_default_quantization_config_options() -> QuantizationConfigOptions:
@@ -39,204 +28,3 @@ def get_default_quantization_config_options() -> QuantizationConfigOptions:
39
28
  return get_current_tp_model().default_qco
40
29
 
41
30
 
42
- def get_default_quantization_config():
43
- """
44
-
45
- Returns: The default OpQuantizationConfig of the model. This is the OpQuantizationConfig
46
- to use when a layer's options is queried and it wasn't specified in the TargetPlatformCapabilities.
47
- This OpQuantizationConfig is the single option in the default QuantizationConfigOptions.
48
-
49
- """
50
-
51
- return get_current_tp_model().get_default_op_quantization_config()
52
-
53
-
54
- class TargetPlatformModel(ImmutableClass):
55
- """
56
- Modeling of the hardware the quantized model will use during inference.
57
- The model contains definition of operators, quantization configurations of them, and
58
- fusing patterns so that multiple operators will be combined into a single operator.
59
- """
60
-
61
- def __init__(self,
62
- default_qco: QuantizationConfigOptions,
63
- add_metadata: bool = False,
64
- name="default_tp_model"):
65
- """
66
-
67
- Args:
68
- default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators whose QuantizationConfigOptions are not defined in the model.
69
- add_metadata (bool): Whether to add metadata to the model or not.
70
- name (str): Name of the model.
71
- """
72
-
73
- super().__init__()
74
- self.add_metadata = add_metadata
75
- self.name = name
76
- self.operator_set = []
77
- assert isinstance(default_qco, QuantizationConfigOptions)
78
- assert len(default_qco.quantization_config_list) == 1, \
79
- f'Default QuantizationConfigOptions must contain only one option'
80
- self.default_qco = default_qco
81
- self.fusing_patterns = []
82
- self.is_simd_padding = False
83
-
84
- def get_config_options_by_operators_set(self,
85
- operators_set_name: str) -> QuantizationConfigOptions:
86
- """
87
- Get the QuantizationConfigOptions of an OperatorsSet by the OperatorsSet name.
88
- If the name is not in the model, the default QuantizationConfigOptions is returned.
89
-
90
- Args:
91
- operators_set_name: Name of OperatorsSet to get.
92
-
93
- Returns:
94
- QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name.
95
- """
96
- for op_set in self.operator_set:
97
- if operators_set_name == op_set.name:
98
- return op_set.qc_options
99
- return self.default_qco
100
-
101
- def get_default_op_quantization_config(self) -> OpQuantizationConfig:
102
- """
103
-
104
- Returns: The default OpQuantizationConfig of the TargetPlatformModel.
105
-
106
- """
107
- assert len(self.default_qco.quantization_config_list) == 1, \
108
- f'Default quantization configuration options must contain only one option,' \
109
- f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.'
110
- return self.default_qco.quantization_config_list[0]
111
-
112
- def is_opset_in_model(self,
113
- opset_name: str) -> bool:
114
- """
115
- Check whether an operators set is defined in the model or not.
116
-
117
- Args:
118
- opset_name: Operators set name to check.
119
-
120
- Returns:
121
- Whether an operators set is defined in the model or not.
122
- """
123
- return opset_name in [x.name for x in self.operator_set]
124
-
125
- def get_opset_by_name(self,
126
- opset_name: str) -> OperatorsSetBase:
127
- """
128
- Get an OperatorsSet object from the model by its name.
129
- If name is not in the model - None is returned.
130
-
131
- Args:
132
- opset_name: OperatorsSet name to retrieve.
133
-
134
- Returns:
135
- OperatorsSet object with the name opset_name, or None if opset_name is not in the model.
136
- """
137
-
138
- opset_list = [x for x in self.operator_set if x.name == opset_name]
139
- assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \
140
- f' TargetPlatformModel with the name {opset_name}. ' \
141
- f'OperatorsSet name must be unique.'
142
- if len(opset_list) == 0: # opset_name is not in the model.
143
- return None
144
-
145
- return opset_list[0] # There's one opset with that name
146
-
147
- def append_component(self,
148
- tp_model_component: TargetPlatformModelComponent):
149
- """
150
- Attach a TargetPlatformModel component to the model. Components can be for example:
151
- Fusing, OperatorsSet, etc.
152
-
153
- Args:
154
- tp_model_component: Component to attach to the model.
155
-
156
- """
157
- if isinstance(tp_model_component, Fusing):
158
- self.fusing_patterns.append(tp_model_component)
159
- elif isinstance(tp_model_component, OperatorsSetBase):
160
- self.operator_set.append(tp_model_component)
161
- else: # pragma: no cover
162
- Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.')
163
-
164
- def __enter__(self):
165
- """
166
- Start defining the TargetPlatformModel using 'with'.
167
-
168
- Returns: Initialized TargetPlatformModel object.
169
-
170
- """
171
- _current_tp_model.set(self)
172
- return self
173
-
174
- def __exit__(self, exc_type, exc_value, tb):
175
- """
176
- Finish defining the TargetPlatformModel at the end of the 'with' clause.
177
- Returns the final and immutable TargetPlatformModel instance.
178
- """
179
-
180
- if exc_value is not None:
181
- print(exc_value, exc_value.args)
182
- raise exc_value
183
- self.__validate_model() # Assert that model is valid.
184
- _current_tp_model.reset()
185
- self.initialized_done() # Make model immutable.
186
- return self
187
-
188
- def __validate_model(self):
189
- """
190
-
191
- Assert model is valid.
192
- Model is invalid if, for example, it contains multiple operator sets with the same name,
193
- as their names should be unique.
194
-
195
- """
196
- opsets_names = [op.name for op in self.operator_set]
197
- if len(set(opsets_names)) != len(opsets_names):
198
- Logger.critical(f'Operator Sets must have unique names.')
199
-
200
- def get_default_config(self) -> OpQuantizationConfig:
201
- """
202
-
203
- Returns:
204
-
205
- """
206
- assert len(self.default_qco.quantization_config_list) == 1, \
207
- f'Default quantization configuration options must contain only one option,' \
208
- f' but found {len(self.default_qco.quantization_config_list)} configurations.'
209
- return self.default_qco.quantization_config_list[0]
210
-
211
- def get_info(self) -> Dict[str, Any]:
212
- """
213
-
214
- Returns: Dictionary that summarizes the TargetPlatformModel properties (for display purposes).
215
-
216
- """
217
- return {"Model name": self.name,
218
- "Default quantization config": self.get_default_config().get_info(),
219
- "Operators sets": [o.get_info() for o in self.operator_set],
220
- "Fusing patterns": [f.get_info() for f in self.fusing_patterns]
221
- }
222
-
223
- def show(self):
224
- """
225
-
226
- Display the TargetPlatformModel.
227
-
228
- """
229
- pprint.pprint(self.get_info(), sort_dicts=False)
230
-
231
- def set_simd_padding(self,
232
- is_simd_padding: bool):
233
- """
234
- Set flag is_simd_padding to indicate whether this TP model defines
235
- that padding due to SIMD constraints occurs.
236
-
237
- Args:
238
- is_simd_padding: Whether this TP model defines that padding due to SIMD constraints occurs.
239
-
240
- """
241
- self.is_simd_padding = is_simd_padding
242
-
@@ -18,8 +18,7 @@ from typing import List, Any, Dict
18
18
  from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
20
20
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
21
- from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
22
- OperatorsSetBase
21
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetConcat
23
22
  from model_compression_toolkit import DefaultDict
24
23
 
25
24