mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -18,8 +18,6 @@ import numpy as np
18
18
 
19
19
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
20
20
  CandidateNodeQuantizationConfig
21
- from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantization_fn,
22
- get_weights_quantization_fn)
23
21
 
24
22
 
25
23
  def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -79,21 +77,20 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
79
77
  quantized_weights = []
80
78
  for qc in node_q_cfg:
81
79
  qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
82
- weights_quantization_fn = get_weights_quantization_fn(qc_weights_attr.weights_quantization_method)
83
- q_weight = weights_quantization_fn(float_weights,
84
- qc_weights_attr.weights_n_bits,
85
- True,
86
- qc_weights_attr.weights_quantization_params,
87
- qc_weights_attr.weights_per_channel_threshold,
88
- qc_weights_attr.weights_channels_axis[0]) # output channel axis
80
+ q_weight = qc_weights_attr.weights_quantization_fn(float_weights,
81
+ qc_weights_attr.weights_n_bits,
82
+ True,
83
+ qc_weights_attr.weights_quantization_params,
84
+ qc_weights_attr.weights_per_channel_threshold,
85
+ qc_weights_attr.weights_channels_axis[
86
+ 0]) # output channel axis
89
87
 
90
88
  quantized_weights.append(fw_tensor_convert_func(q_weight))
91
89
 
92
90
  return quantized_weights
93
91
 
94
92
 
95
- def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig],
96
- get_activation_quantization_fn_factory: Callable) -> List:
93
+ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]) -> List:
97
94
  """
98
95
  Builds a list of quantizers for each of the bitwidth candidates for activation quantization,
99
96
  to be stored and used during MP search.
@@ -101,7 +98,6 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
101
98
  Args:
102
99
  node_q_cfg: Quantization configuration candidates of the node that generated the layer that will
103
100
  use this quantizer.
104
- get_activation_quantization_fn_factory: activation quantization functions factory.
105
101
 
106
102
  Returns: a list of activation quantizers - for each bitwidth and layer's attribute to be quantized.
107
103
  """
@@ -109,7 +105,6 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
109
105
  activation_quantizers = []
110
106
  for index, qc in enumerate(node_q_cfg):
111
107
  q_activation = node_q_cfg[index].activation_quantization_cfg
112
- quantizer = get_activation_quantization_fn(q_activation, get_activation_quantization_fn_factory)
113
- activation_quantizers.append(quantizer)
108
+ activation_quantizers.append(q_activation.quantize_node_output)
114
109
 
115
110
  return activation_quantizers
@@ -12,12 +12,18 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.core import ResourceUtilization
15
+ import numpy as np
16
+
17
+ from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
16
18
  from model_compression_toolkit.core.common import Graph
19
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
20
+ FrameworkQuantizationCapabilities
17
21
 
18
22
 
19
23
  def filter_candidates_for_mixed_precision(graph: Graph,
20
- target_resource_utilization: ResourceUtilization):
24
+ target_resource_utilization: ResourceUtilization,
25
+ fw_info: FrameworkInfo,
26
+ fqc: FrameworkQuantizationCapabilities):
21
27
  """
22
28
  Filters out candidates in case of mixed precision search for only weights or activation compression.
23
29
  For instance, if running only weights compression - filters out candidates of activation configurable nodes
@@ -29,6 +35,9 @@ def filter_candidates_for_mixed_precision(graph: Graph,
29
35
  Args:
30
36
  graph: A graph representation of the model to be quantized.
31
37
  target_resource_utilization: The resource utilization of the target device.
38
+ fw_info: Information needed for quantization about the specific framework.
39
+ fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
40
+
32
41
  """
33
42
 
34
43
  tru = target_resource_utilization
@@ -40,21 +49,21 @@ def filter_candidates_for_mixed_precision(graph: Graph,
40
49
  # filter out candidates activation only configurable node
41
50
  activation_configurable_nodes = [n for n in graph.get_activation_configurable_nodes()]
42
51
  for n in activation_configurable_nodes:
43
- base_cfg_nbits = n.quantization_cfg.base_quantization_cfg.activation_quantization_cfg.activation_n_bits
44
- filtered_cfgs = [c for c in n.candidates_quantization_cfg if
52
+ base_cfg_nbits = n.get_qco(fqc).base_config.activation_n_bits
53
+ filtered_conf = [c for c in n.candidates_quantization_cfg if
45
54
  c.activation_quantization_cfg.enable_activation_quantization and
46
55
  c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]
47
56
 
48
- n.quantization_cfg.candidates_quantization_cfg = filtered_cfgs
57
+ n.candidates_quantization_cfg = filtered_conf
49
58
 
50
59
  elif tru.activation_restricted() and not tru.weight_restricted():
51
60
  # Running mixed precision for activation compression only -
52
61
  # filter out candidates weights only configurable node
53
- weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes()]
62
+ weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes(fw_info)]
54
63
  for n in weight_configurable_nodes:
55
- base_cfg_nbits = (n.quantization_cfg.base_quantization_cfg.weights_quantization_cfg.
56
- get_attr_config(n.kernel_attr).weights_n_bits)
57
- filtered_cfgs = [c for c in n.candidates_quantization_cfg if
58
- c.weights_quantization_cfg.get_attr_config(n.kernel_attr).enable_weights_quantization and
59
- c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits == base_cfg_nbits]
60
- n.quantization_cfg.candidates_quantization_cfg = filtered_cfgs
64
+ kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
65
+ base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
66
+ filtered_conf = [c for c in n.candidates_quantization_cfg if
67
+ c.weights_quantization_cfg.get_attr_config(kernel_attr).enable_weights_quantization and
68
+ c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == base_cfg_nbits]
69
+ n.candidates_quantization_cfg = filtered_conf
@@ -30,10 +30,11 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
30
30
  class MixedPrecisionRUHelper:
31
31
  """ Helper class for resource utilization computations for mixed precision optimization. """
32
32
 
33
- def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
33
+ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation):
34
34
  self.graph = graph
35
+ self.fw_info = fw_info
35
36
  self.fw_impl = fw_impl
36
- self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
37
+ self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
37
38
 
38
39
  def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Dict[BaseNode, int]) -> Dict[RUTarget, np.ndarray]:
39
40
  """
@@ -35,6 +35,7 @@ class BitWidthSearchMethod(Enum):
35
35
 
36
36
 
37
37
  def search_bit_width(graph: Graph,
38
+ fw_info: FrameworkInfo,
38
39
  fw_impl: FrameworkImplementation,
39
40
  target_resource_utilization: ResourceUtilization,
40
41
  mp_config: MixedPrecisionQuantizationConfig,
@@ -51,6 +52,7 @@ def search_bit_width(graph: Graph,
51
52
 
52
53
  Args:
53
54
  graph: Graph to search a MP configuration for.
55
+ fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
54
56
  fw_impl: FrameworkImplementation object with specific framework methods implementation.
55
57
  target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
56
58
  mp_config: Mixed-precision quantization configuration.
@@ -77,7 +79,7 @@ def search_bit_width(graph: Graph,
77
79
 
78
80
  # Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
79
81
  # even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
80
- se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen,
82
+ se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen, fw_info=fw_info,
81
83
  fw_impl=fw_impl, disable_activation_for_metric=disable_activation_for_metric,
82
84
  hessian_info_service=hessian_info_service)
83
85
 
@@ -91,6 +93,7 @@ def search_bit_width(graph: Graph,
91
93
 
92
94
  # Search manager and LP are highly coupled, so LP search method was moved inside search manager.
93
95
  search_manager = MixedPrecisionSearchManager(graph,
96
+ fw_info=fw_info,
94
97
  fw_impl=fw_impl,
95
98
  sensitivity_evaluator=se,
96
99
  target_resource_utilization=target_resource_utilization,
@@ -102,6 +105,6 @@ def search_bit_width(graph: Graph,
102
105
  if mp_config.refine_mp_solution:
103
106
  nodes_bit_cfg = greedy_solution_refinement_procedure(nodes_bit_cfg, search_manager, target_resource_utilization)
104
107
 
105
- topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes()]
108
+ topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes(fw_info)]
106
109
  assert len(topo_bit_cfg) == len(nodes_bit_cfg)
107
110
  return topo_bit_cfg
@@ -53,6 +53,7 @@ class MixedPrecisionSearchManager:
53
53
 
54
54
  def __init__(self,
55
55
  graph: Graph,
56
+ fw_info: FrameworkInfo,
56
57
  fw_impl: FrameworkImplementation,
57
58
  sensitivity_evaluator: SensitivityEvaluation,
58
59
  target_resource_utilization: ResourceUtilization,
@@ -61,12 +62,14 @@ class MixedPrecisionSearchManager:
61
62
 
62
63
  Args:
63
64
  graph: Graph to search for its MP configuration.
65
+ fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
64
66
  fw_impl: FrameworkImplementation object with specific framework methods implementation.
65
67
  sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
66
68
  a bit-width configuration for the MP model.
67
69
  target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
68
70
  """
69
71
 
72
+ self.fw_info = fw_info
70
73
  self.fw_impl = fw_impl
71
74
 
72
75
  self.original_graph = graph
@@ -78,12 +81,12 @@ class MixedPrecisionSearchManager:
78
81
  self.target_resource_utilization = target_resource_utilization
79
82
  self.mp_config = mp_config
80
83
 
81
- self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes()
84
+ self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
82
85
 
83
86
  self.ru_targets = target_resource_utilization.get_restricted_targets()
84
- self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_impl)
87
+ self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
85
88
 
86
- self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config()
89
+ self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
87
90
 
88
91
  self.config_reconstructor = None
89
92
  orig_min_config = self.min_ru_config
@@ -124,9 +124,10 @@ class ResourceUtilizationCalculator:
124
124
  unexpected_qc_error = 'Custom quantization configuration is not expected for non-custom bit mode.'
125
125
  unexpected_qc_nodes_error = 'Custom quantization configuration contains unexpected node names.'
126
126
 
127
- def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
127
+ def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo):
128
128
  self.graph = graph
129
129
  self.fw_impl = fw_impl
130
+ self.fw_info = fw_info
130
131
 
131
132
  # Currently we go over the full graph even if utilization won't be requested for all nodes.
132
133
  # We could fill the cache on the fly only for requested nodes, but it's probably negligible.
@@ -543,10 +544,14 @@ class ResourceUtilizationCalculator:
543
544
  self._validate_custom_qcs(w_qc, bitwidth_mode)
544
545
 
545
546
  # check if the node has kernel
546
- if not n.kernel_attr:
547
+ kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
548
+ if len(kernel_attrs) > 1: # pragma: no cover
549
+ raise NotImplementedError('Multiple kernel attributes are not supported for BOPS computation.')
550
+ if not kernel_attrs or not kernel_attrs[0]:
547
551
  return 0
548
552
 
549
- node_mac = self.fw_impl.get_node_mac_operations(n)
553
+ kernel_attr = kernel_attrs[0]
554
+ node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info)
550
555
  if node_mac == 0:
551
556
  return node_mac
552
557
 
@@ -554,12 +559,12 @@ class ResourceUtilizationCalculator:
554
559
  assert len(prev_nodes) == 1, f'Weights node is expected to have exactly one input, {n} has {len(prev_nodes)}'
555
560
  a_node = prev_nodes[0]
556
561
  if (target_criterion == TargetInclusionCriterion.AnyQuantized and
557
- not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(n.kernel_attr))):
562
+ not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
558
563
  return 0
559
564
 
560
565
  act_qc = self._extract_qc(a_node, act_qcs)
561
566
  a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
562
- w_nbits = self._get_weight_nbits(n, n.kernel_attr, bitwidth_mode, w_qc)
567
+ w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
563
568
  node_bops = a_nbits * w_nbits * node_mac
564
569
  return node_bops
565
570
 
@@ -15,7 +15,7 @@
15
15
  import copy
16
16
  from typing import Callable, Any
17
17
 
18
- from model_compression_toolkit.core import ResourceUtilization, CoreConfig, QuantizationErrorMethod
18
+ from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod
19
19
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
20
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
21
21
  ResourceUtilizationCalculator, BitwidthMode, TargetInclusionCriterion
@@ -27,6 +27,7 @@ def compute_resource_utilization_data(in_model: Any,
27
27
  representative_data_gen: Callable,
28
28
  core_config: CoreConfig,
29
29
  fqc: FrameworkQuantizationCapabilities,
30
+ fw_info: FrameworkInfo,
30
31
  fw_impl: FrameworkImplementation) -> ResourceUtilization:
31
32
  """
32
33
  Compute Resource Utilization of a model with the default single precision quantization.
@@ -38,6 +39,7 @@ def compute_resource_utilization_data(in_model: Any,
38
39
  core_config: CoreConfig containing parameters of how the model should be quantized.
39
40
  fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
40
41
  the attached framework operator's information.
42
+ fw_info: Information needed for quantization about the specific framework.
41
43
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
42
44
 
43
45
  Returns:
@@ -53,11 +55,12 @@ def compute_resource_utilization_data(in_model: Any,
53
55
  transformed_graph = graph_preparation_runner(in_model,
54
56
  representative_data_gen=representative_data_gen,
55
57
  quantization_config=core_config.quantization_config,
58
+ fw_info=fw_info,
56
59
  fw_impl=fw_impl,
57
60
  fqc=fqc,
58
61
  bit_width_config=core_config.bit_width_config,
59
62
  mixed_precision_enable=False,
60
63
  running_gptq=False)
61
64
 
62
- ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl)
65
+ ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
63
66
  return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
@@ -15,7 +15,7 @@
15
15
  import numpy as np
16
16
  from typing import runtime_checkable, Protocol, Callable, Any, List, Tuple
17
17
 
18
- from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, MpDistanceWeighting
18
+ from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig, MpDistanceWeighting
19
19
  from model_compression_toolkit.core.common import Graph, BaseNode
20
20
  from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
21
21
  HessianScoresGranularity
@@ -62,12 +62,15 @@ class DistanceMetricCalculator(MetricCalculator):
62
62
  graph: Graph,
63
63
  mp_config: MixedPrecisionQuantizationConfig,
64
64
  representative_data_gen: Callable,
65
+ fw_info: FrameworkInfo,
65
66
  fw_impl: Any,
66
67
  hessian_info_service: HessianInfoService = None):
67
68
  """
68
69
  Args:
69
70
  graph: Graph to search for its MP configuration.
70
71
  mp_config: MP Quantization configuration for how the graph should be quantized.
72
+ fw_info: FrameworkInfo object about the specific framework
73
+ (e.g., attributes of different layers' weights to quantize).
71
74
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
72
75
  representative_data_gen: Dataset used for getting batches for inference.
73
76
  hessian_info_service: HessianInfoService to fetch Hessian approximation information.
@@ -75,13 +78,14 @@ class DistanceMetricCalculator(MetricCalculator):
75
78
  self.graph = graph
76
79
  self.mp_config = mp_config
77
80
  self.representative_data_gen = representative_data_gen
81
+ self.fw_info = fw_info
78
82
  self.fw_impl = fw_impl
79
83
 
80
84
  if self.mp_config.distance_weighting_method == MpDistanceWeighting.HESSIAN:
81
85
  assert hessian_info_service is not None, ('Expected HessianInfoService object to be passed with Hessian '
82
86
  'distance weighting')
83
87
 
84
- self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names()
88
+ self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names(self.fw_info)
85
89
 
86
90
  # Get interest points and output points set for distance measurement and set other helper datasets
87
91
  # We define a separate set of output nodes of the model for the purpose of sensitivity computation.
@@ -392,8 +396,9 @@ class DistanceMetricCalculator(MetricCalculator):
392
396
  """
393
397
 
394
398
  return [n.node for n in graph.get_outputs()
395
- if (n.node.kernel_attr and n.node.is_weights_quantization_enabled(n.node.kernel_attr))
396
- or n.node.is_activation_quantization_enabled()]
399
+ if (graph.fw_info.is_kernel_op(n.node.type) and
400
+ n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or
401
+ n.node.is_activation_quantization_enabled()]
397
402
 
398
403
  @staticmethod
399
404
  def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]:
@@ -38,6 +38,7 @@ class SensitivityEvaluation:
38
38
  graph: Graph,
39
39
  mp_config: MixedPrecisionQuantizationConfig,
40
40
  representative_data_gen: Callable,
41
+ fw_info: FrameworkInfo,
41
42
  fw_impl: Any,
42
43
  disable_activation_for_metric: bool = False,
43
44
  hessian_info_service: HessianInfoService = None
@@ -45,6 +46,8 @@ class SensitivityEvaluation:
45
46
  """
46
47
  Args:
47
48
  graph: Graph to search for its MP configuration.
49
+ fw_info: FrameworkInfo object about the specific framework
50
+ (e.g., attributes of different layers' weights to quantize).
48
51
  mp_config: MP Quantization configuration for how the graph should be quantized.
49
52
  representative_data_gen: Dataset used for getting batches for inference.
50
53
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
@@ -54,13 +57,14 @@ class SensitivityEvaluation:
54
57
  """
55
58
  self.mp_config = mp_config
56
59
  self.representative_data_gen = representative_data_gen
60
+ self.fw_info = fw_info
57
61
  self.fw_impl = fw_impl
58
62
 
59
63
  if self.mp_config.custom_metric_fn:
60
64
  self.metric_calculator = CustomMetricCalculator(graph, self.mp_config.custom_metric_fn)
61
65
  else:
62
66
  self.metric_calculator = DistanceMetricCalculator(graph, mp_config, representative_data_gen,
63
- fw_impl=fw_impl,
67
+ fw_info=fw_info, fw_impl=fw_impl,
64
68
  hessian_info_service=hessian_info_service)
65
69
 
66
70
  # Build a mixed-precision model which can be configured to use different bitwidth in different layers.
@@ -107,7 +111,8 @@ class SensitivityEvaluation:
107
111
 
108
112
  model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
109
113
  mode=ModelBuilderMode.MIXEDPRECISION,
110
- append2output=outputs)
114
+ append2output=outputs,
115
+ fw_info=self.fw_info)
111
116
 
112
117
  # Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
113
118
  for layer in itertools.chain(*conf_node2layers.values()):
@@ -50,11 +50,8 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
50
50
  if target_resource_utilization.bops_restricted():
51
51
  Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
52
52
  return mp_solution
53
- assert search_manager.using_virtual_graph is False
54
53
 
55
- tru = target_resource_utilization
56
- activation_restricted = tru.activation_restricted() or tru.total_mem_restricted() or tru.bops_restricted()
57
- weights_restricted = tru.weight_restricted() or tru.total_mem_restricted() or tru.bops_restricted()
54
+ assert search_manager.using_virtual_graph is False
58
55
 
59
56
  new_solution = mp_solution.copy()
60
57
  changed = True
@@ -65,7 +62,7 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
65
62
  nodes_next_candidate = {}
66
63
 
67
64
  for node in search_manager.mp_topo_configurable_nodes:
68
- if new_solution[node] == node.find_max_candidate_index():
65
+ if new_solution[node] == 0:
69
66
  # layer has max config in the given solution, nothing to optimize
70
67
  continue
71
68
 
@@ -74,8 +71,9 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
74
71
  # only weights kernel attribute is quantized with weights mixed precision
75
72
  valid_candidates = _get_valid_candidates_indices(node_candidates,
76
73
  new_solution[node],
77
- activation_restricted,
78
- weights_restricted)
74
+ target_resource_utilization.activation_restricted(),
75
+ target_resource_utilization.weight_restricted()
76
+ )
79
77
 
80
78
  # Create a list of ru for the valid candidates.
81
79
  updated_ru = []
@@ -18,7 +18,7 @@ import numpy as np
18
18
  from typing import List, Union, Tuple, Optional
19
19
 
20
20
  from networkx.algorithms.dag import topological_sort
21
- from model_compression_toolkit.core import QuantizationErrorMethod
21
+ from model_compression_toolkit.core import FrameworkInfo, QuantizationErrorMethod
22
22
  from model_compression_toolkit.core import common
23
23
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
24
24
  from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -30,6 +30,7 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
30
30
 
31
31
 
32
32
  def create_stats_collector_for_node(node: common.BaseNode,
33
+ fw_info: FrameworkInfo,
33
34
  quant_node_in_fln: bool) -> BaseStatsCollector:
34
35
  """
35
36
  Gets a node and a groups list and create and return a statistics collector for a node
@@ -38,7 +39,7 @@ def create_stats_collector_for_node(node: common.BaseNode,
38
39
 
39
40
  Args:
40
41
  node: Node to create its statistics collector.
41
- quant_node_in_fln: Whether the node should be quantized as part of an FLN.
42
+ fw_info: Information relevant to a specific framework about what is out channel axis (for statistics per-channel).
42
43
 
43
44
  Returns:
44
45
  Statistics collector for statistics collection for the node.
@@ -47,7 +48,7 @@ def create_stats_collector_for_node(node: common.BaseNode,
47
48
  if node.is_activation_quantization_enabled() or quant_node_in_fln:
48
49
  min_output = getattr(node.prior_info, 'min_output', None)
49
50
  max_output = getattr(node.prior_info, 'max_output', None)
50
- stats_collector = common.StatsCollector(out_channel_axis=node.out_channel_axis,
51
+ stats_collector = common.StatsCollector(out_channel_axis=fw_info.out_channel_axis_mapping.get(node.type),
51
52
  init_min_value=min_output,
52
53
  init_max_value=max_output)
53
54
  else:
@@ -58,20 +59,20 @@ def create_stats_collector_for_node(node: common.BaseNode,
58
59
 
59
60
def create_tensor2node(graph: common.Graph,
                       node: common.BaseNode,
                       fw_info: common.FrameworkInfo) -> None:
    """
    Force statistic collector creation and assignment for a node.

    A new StatsCollector is attached to the node only if it currently has no
    meaningful collector: none at all, a single NoStatsCollector, or a list
    consisting solely of NoStatsCollector instances.

    Args:
        graph: Graph of the node (for retrieving the current tensor).
        node: Node to create a tensor for.
        fw_info: Specific framework information (for example, output channels index).
    """
    current_sc = graph.get_out_stats_collector(node)
    # A list counts as "no statistics" only when every entry is a NoStatsCollector.
    is_list_nostat_collectors = isinstance(current_sc, list) and all(
        isinstance(sc, common.NoStatsCollector) for sc in current_sc)
    if current_sc is None or isinstance(current_sc, common.NoStatsCollector) or is_list_nostat_collectors:
        # The out-channel axis (looked up per node type) drives per-channel statistics.
        stats_collector = common.StatsCollector(fw_info.out_channel_axis_mapping.get(node.type))
        graph.set_out_stats_collector_to_node(node, stats_collector)
76
77
 
77
78
 
@@ -139,6 +140,7 @@ class ModelCollector:
139
140
 
140
141
  def __init__(self, graph: Graph,
141
142
  fw_impl: FrameworkImplementation,
143
+ fw_info: FrameworkInfo,
142
144
  hessian_info_service: HessianInfoService = None,
143
145
  qc: common.QuantizationConfig = common.DEFAULTCONFIG):
144
146
  """
@@ -147,10 +149,12 @@ class ModelCollector:
147
149
  Args:
148
150
  graph: Graph to build a model from it.
149
151
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
152
+ fw_info: FrameworkInfo object with a specific framework information.
150
153
  qc: Quantization configuration containing parameters for how the graph should be quantized.
151
154
  """
152
155
 
153
156
  self.fw_impl = fw_impl
157
+ self.fw_info = fw_info
154
158
  self.hessian_service = hessian_info_service
155
159
  self.qc = qc
156
160
  self.model_outputs = [out.node for out in graph.get_outputs()]
@@ -158,27 +162,17 @@ class ModelCollector:
158
162
  # Assign statistics collectors to nodes
159
163
  for n in graph.get_topo_sorted_nodes():
160
164
  quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
161
- sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
162
- if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
163
- # Missing output channel axis info, so try to extract it from previous and next nodes output channel axis.
164
- possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
165
- # Filter out None values.
166
- possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
167
- if len(possible_output_channel_axis_list) > 0:
168
- if len(possible_output_channel_axis_list) > 1:
169
- Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
170
- sc.mc.axis = possible_output_channel_axis_list[0]
171
- sc.mpcc.axis = possible_output_channel_axis_list[0]
172
-
165
+ sc = create_stats_collector_for_node(n, fw_info=fw_info, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
173
166
  # If we use bias correction, and the node has kernel weights to quantize, we need to make sure
174
167
  # its previous nodes' tensors are consistent with this node.
175
- if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(
176
- n.kernel_attr):
168
+ kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
169
+ if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(
170
+ kernel_attr):
177
171
  for ie in graph.incoming_edges(n):
178
172
  input_node = ie.source_node
179
173
  create_tensor2node(graph,
180
174
  input_node,
181
- n.out_channel_axis)
175
+ fw_info)
182
176
  if sc is not None:
183
177
  graph.set_out_stats_collector_to_node(n, sc)
184
178
 
@@ -211,11 +205,13 @@ class ModelCollector:
211
205
  # TODO: Add integration test for this case
212
206
  append2output = outputs_nodes + [n for n in self.model_outputs if n not in outputs_nodes]
213
207
 
208
+
214
209
  # Build a float model and output all layers' outputs
215
210
  # (that should be collected) as the model's outputs
216
211
  self.model, _ = self.fw_impl.model_builder(graph,
217
212
  mode=ModelBuilderMode.FLOAT,
218
- append2output=append2output)
213
+ append2output=append2output,
214
+ fw_info=self.fw_info)
219
215
 
220
216
  def infer(self, inputs_list: List[np.ndarray]):
221
217
  """
@@ -0,0 +1,44 @@
1
from abc import abstractmethod
from typing import Any

from model_compression_toolkit.core import FrameworkInfo


class ModelValidation:
    """
    Class to define validation methods in order to validate the received model to quantize.
    """

    def __init__(self,
                 model: Any,
                 fw_info: FrameworkInfo):
        """
        Initialize a ModelValidation object.

        Args:
            model: Model to check its validity.
            fw_info: Information about the specific framework of the model.
        """
        self.model = model
        self.fw_info = fw_info

    @abstractmethod
    def validate_output_channel_consistency(self):
        """
        Validate that output channels index in all layers of the model are the same.
        If the model has layers with different output channels index, it should throw an exception.
        """
        # NOTE: the class does not inherit from abc.ABC, so @abstractmethod is not
        # enforced at instantiation; the explicit raise below is the actual guard.
        # Bug fix: `NotImplemented` is a non-callable singleton (raising it yields a
        # TypeError); the correct exception type is NotImplementedError.
        raise NotImplementedError(
            'Framework validation class did not implement validate_output_channel_consistency')  # pragma: no cover

    def validate(self):
        """
        Run all validation methods before the quantization process starts.
        """
        self.validate_output_channel_consistency()
@@ -13,14 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.network_editors.actions import (
17
- ChangeCandidatesWeightsQuantConfigAttr,
18
- ChangeFinalWeightsQuantConfigAttr,
19
- ChangeCandidatesActivationQuantConfigAttr,
20
- ChangeCandidatesActivationQuantizationMethod,
21
- ChangeFinalWeightsQuantizationMethod,
22
- ChangeCandidatesWeightsQuantizationMethod,
23
- ChangeFinalActivationQuantConfigAttr)
16
+ from model_compression_toolkit.core.common.network_editors.actions import ChangeCandidatesWeightsQuantConfigAttr, ChangeFinalWeightsQuantConfigAttr, ChangeCandidatesActivationQuantConfigAttr, ChangeQuantizationParamFunction, ChangeCandidatesActivationQuantizationMethod, ChangeFinalWeightsQuantizationMethod, ChangeCandidatesWeightsQuantizationMethod, ChangeFinalActivationQuantConfigAttr
24
17
  from model_compression_toolkit.core.common.network_editors.actions import EditRule
25
18
  from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter, NodeNameScopeFilter, \
26
19
  NodeNameFilter