mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/graph/base_graph.py (+72 -45)

@@ -23,6 +23,7 @@ import numpy as np
 
 from networkx.algorithms.dag import topological_sort
 
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
 from model_compression_toolkit.core.common.graph.edge import Edge, convert_to_edge
@@ -32,8 +33,7 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
 from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics
 from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.common.quantization.node_quantization_config import \
-    NodeActivationQuantizationConfig, ActivationQuantizationMode
+from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
@@ -74,6 +74,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
                  input_nodes: List[BaseNode],
                  output_nodes: List[OutTensor],
                  edge_list: List[Edge],
+                 fw_info: FrameworkInfo = None,
                  **attr):
         """
         Args:
@@ -81,6 +82,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             input_nodes: List of input nodes the model
             output_nodes: List of output nodes of the model to a list of their output indices.
             edge_list: List of edges the graph has between nodes.
+            fw_info: FrameworkInfo object (needed for computing the graph's weights memory).
             **attr: Attributes to add to graph as key=value pairs.
         """
 
@@ -101,6 +103,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
                           e.sink_node,
                           **e.get_attributes())
         self.user_info = UserInformation()
+        self.fw_info = fw_info
 
     @property
     def skip_validation_check(self) -> bool:
@@ -121,13 +124,38 @@ class Graph(nx.MultiDiGraph, GraphSearches):
     def fusing_info(self, fusing_info: FusingInfo):
         self._fusing_info = fusing_info
 
-    def set_fqc(self, fqc: FrameworkQuantizationCapabilities):
+    def set_fw_info(self,
+                    fw_info: FrameworkInfo):
+        """
+        Set the graph's framework info.
+        Args:
+            fw_info: FrameworkInfo object.
+        """
+
+        self.fw_info = fw_info
+
+    def set_fqc(self,
+                fqc: FrameworkQuantizationCapabilities):
         """
         Set the graph's FQC.
         Args:
            fqc: FrameworkQuantizationCapabilities object.
         """
-        # TODO irena: this is only passed for negative shift activation.
+        # validate graph nodes are either from the framework or a custom layer defined in the FQC
+        # Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC
+        fqc_layers = fqc.op_sets_to_layers.get_layers()
+        fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)]
+        for n in self.nodes:
+            is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \
+                             any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers])
+            if n.is_custom:
+                if not is_node_in_fqc:
+                    Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
                                    ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
                                    'request or an issue if you believe this should be supported.')  # pragma: no cover
+                if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]):
+                    Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')  # pragma: no cover
+
         self.fqc = fqc
 
     def get_topo_sorted_nodes(self):
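For callers, the hunks above mean a Graph can now carry a FrameworkInfo object, either passed at construction or attached through the new set_fw_info() setter. A minimal sketch of the call pattern follows; the helper name and graph_ctor_args are illustrative placeholders, and only the Graph/set_fw_info signatures come from this diff:

    from model_compression_toolkit.core.common.graph.base_graph import Graph

    def build_graph_with_fw_info(graph_ctor_args, fw_info):
        # Hypothetical helper: `graph_ctor_args` abbreviates the unchanged positional
        # arguments of Graph(), and `fw_info` is the caller's FrameworkInfo instance.
        graph = Graph(*graph_ctor_args, fw_info=fw_info)  # fw_info is now accepted at construction
        # Equivalently, it can be attached after construction with the new setter:
        graph.set_fw_info(fw_info)
        return graph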
@@ -535,6 +563,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return output_edges
 
     def get_configurable_sorted_nodes_names(self,
+                                            fw_info: FrameworkInfo,
                                             include_reused_nodes: bool = False) -> List[str]:
         """
         Get a list of nodes' names that can be configured (namely, has one or
@@ -542,49 +571,56 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         order of the graph.
 
         Args:
+            fw_info: FrameworkInfo object with information about the specific framework's model.
            include_reused_nodes: Whether or not to include reused nodes (False by default).
 
        Returns: List of nodes' names that can be configured (namely, has one or
        more weight qc candidate) sorted topology.
 
        """
-        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes=include_reused_nodes)]
+        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(fw_info=fw_info,
+                                                                           include_reused_nodes=include_reused_nodes)]
         return sorted_names
 
     def get_weights_configurable_nodes(self,
+                                       fw_info: FrameworkInfo,
                                        include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of nodes that their weights can be configured (namely, has one or
         more weight qc candidate and their weights should be quantized).
 
         Args:
+            fw_info: FrameworkInfo object with information about the specific framework's model.
            include_reused_nodes: Whether to include reused nodes (False by default).
 
        Returns:
            A list of nodes that their weights can be configured (namely, has one or more weight qc candidate).
        """
        # configurability is only relevant for kernel attribute quantization
-        potential_conf_nodes = [n for n in self.nodes if n.kernel_attr]
+        potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)]
 
        def is_configurable(n):
-            return n.is_configurable_weight(n.kernel_attr) and (not n.reuse or include_reused_nodes)
+            kernel_attrs = fw_info.get_kernel_op_attributes(n.type)
+            return any(n.is_configurable_weight(attr) for attr in kernel_attrs) and (not n.reuse or include_reused_nodes)
 
        return [n for n in potential_conf_nodes if is_configurable(n)]
 
     def get_sorted_weights_configurable_nodes(self,
+                                              fw_info: FrameworkInfo,
                                               include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of sorted nodes that their weights can be configured (namely, has one or
         more weight qc candidate and their weights should be quantized).
 
         Args:
+            fw_info: FrameworkInfo object with information about the specific framework's model.
            include_reused_nodes: Whether to include reused nodes (False by default).
 
        Returns:
            A list of nodes that their weights can be configured (namely, has one or more weight qc candidate)
            sorted topologically.
        """
-        return self._sort_nodes_in_list(self.get_weights_configurable_nodes(include_reused_nodes))
+        return self._sort_nodes_in_list(self.get_weights_configurable_nodes(fw_info, include_reused_nodes))
 
     def get_activation_configurable_nodes(self) -> List[BaseNode]:
         """
@@ -608,6 +644,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return self._sort_nodes_in_list(self.get_activation_configurable_nodes())
 
     def get_configurable_sorted_nodes(self,
+                                      fw_info: FrameworkInfo,
                                       include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of nodes that can be configured (namely, has one or
@@ -615,13 +652,14 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         The nodes are sorted according to the topological order of the graph.
 
         Args:
+            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
            include_reused_nodes: Whether or not to include reused nodes (False by default).
 
        Returns:
            A list of nodes that can be configured (namely, has one or more qc candidate) sorted topology.
 
        """
-        weights_configurable_nodes = self.get_weights_configurable_nodes(include_reused_nodes)
+        weights_configurable_nodes = self.get_weights_configurable_nodes(fw_info, include_reused_nodes)
         activation_configurable_nodes = self.get_activation_configurable_nodes()
 
         # combine and remove duplications
@@ -646,7 +684,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             sorted_configurable_nodes.append(n)
         return sorted_configurable_nodes
 
-    def get_min_candidates_config(self) -> Dict[BaseNode, int]:
+    def get_min_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
         """
         Builds a minimal configuration.
         Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate
@@ -659,33 +697,38 @@ class Graph(nx.MultiDiGraph, GraphSearches):
        Returns:
            A dict from layer to an index of its minimal candidate.
        """
-        conf_sorted_nodes = self.get_configurable_sorted_nodes()
+        conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
         return {n: n.find_min_candidate_index() for n in conf_sorted_nodes}
 
-    def get_max_candidates_config(self) -> Dict[BaseNode, int]:
+    def get_max_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
         """
         Builds a maximal configuration.
         Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate
         with maximal n_bits (in both weight and activation if both are quantized, or in the relevant one if only
         one of them is quantized)
 
+        Args:
+            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
+
        Returns:
            A dict from layer to an index of its maximal candidate.
        """
-        conf_sorted_nodes = self.get_configurable_sorted_nodes()
+        conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
         return {n: n.find_max_candidate_index() for n in conf_sorted_nodes}
 
-    def get_final_weights_config(self) -> List[Tuple[BaseNode, int]]:
+    def get_final_weights_config(self, fw_info: FrameworkInfo) -> List[Tuple[BaseNode, int]]:
         """
         Gets the final number of bits for quantization of each weights' configurable layer.
 
-        Returns:
-            A list of pairs of (node type, node's weights quantization bitwidth).
+        Args:
+            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
+
+        Returns: A list of pairs of (node type, node's weights quantization bitwidth).
 
        """
-        sorted_conf_weights = self.get_sorted_weights_configurable_nodes()
+        sorted_conf_weights = self.get_sorted_weights_configurable_nodes(fw_info)
         # a configurable node by definition has a kernel op
-        return [(n, n.final_weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits)
+        return [(n, n.final_weights_quantization_cfg.get_attr_config(self.fw_info.get_kernel_op_attributes(n.type)[0]).weights_n_bits)
                 for n in sorted_conf_weights]
 
     def get_final_activation_config(self) -> List[Tuple[BaseNode, int]]:
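Taken together, the hunks above make the configurable-node queries require an explicit FrameworkInfo argument and resolve kernel attributes through it (fw_info.is_kernel_op / fw_info.get_kernel_op_attributes) instead of through n.kernel_attr. A call-site migration sketch follows; the helper name is hypothetical, `graph` and `fw_info` are assumed to exist, and only the method signatures come from this diff:

    from typing import Dict, List

    def collect_mixed_precision_inputs(graph, fw_info):
        # Each call below previously took no fw_info argument.
        names: List[str] = graph.get_configurable_sorted_nodes_names(fw_info)
        w_nodes = graph.get_weights_configurable_nodes(fw_info, include_reused_nodes=False)
        conf_nodes = graph.get_configurable_sorted_nodes(fw_info)
        min_cfg: Dict = graph.get_min_candidates_config(fw_info)   # node -> index of minimal candidate
        max_cfg: Dict = graph.get_max_candidates_config(fw_info)   # node -> index of maximal candidate
        final_w = graph.get_final_weights_config(fw_info)          # (node, weights bitwidth) pairs
        return names, w_nodes, conf_nodes, min_cfg, max_cfg, final_w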
@@ -803,7 +846,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             next_node = self.out_edges(next_node)[0].sink_node
 
         # If next_node is an exit node and has only one incoming edge, the topology is prunable.
-        if fw_impl.is_node_exit_node(next_node, entry_node) and len(self.in_edges(next_node)) == 1:
+        if fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info) and len(self.in_edges(next_node)) == 1:
             return True
 
         # If the next node is not an intermediate node or has more than one incoming/outgoing edge,
@@ -833,7 +876,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
 
         intermediate_nodes, exit_node = self._find_intermediate_and_exit_nodes(entry_node, fw_impl)
 
-        if not fw_impl.is_node_exit_node(exit_node, entry_node):
+        if not fw_impl.is_node_exit_node(exit_node, entry_node, self.fw_info):
            Logger.critical(f"Node {exit_node} is not a valid exit node for the pruning section starting with {entry_node}.")  # pragma: no cover
 
         return PruningSection(entry_node=entry_node,
@@ -854,37 +897,21 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         """
         intermediate_nodes = []
         next_node = self.out_edges(entry_node)[0].sink_node
-        while not fw_impl.is_node_exit_node(next_node, entry_node):
+        while not fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info):
             intermediate_nodes.append(next_node)
             next_node = self.out_edges(next_node)[0].sink_node
 
         return intermediate_nodes, next_node
 
-    # TODO irena move to load_fqc and clean up tests (currently tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py)
-    def override_fused_node_activation_quantization_candidates(self):
+    def disable_fused_nodes_activation_quantization(self):
         """
-        Override fused node activation quantization candidates for all nodes in fused operations,
+        Disable activation quantization for all nodes in fused operations,
         except for the last node in each fused group.
-        Update the value of quantization_config with the value of op_quaitization_cfg from FusingInfo.
-        """
-        nodes_in_fln = self.fusing_info.get_inner_fln_nodes()
-        for node in nodes_in_fln:
-            fused_node_op_id = self.fusing_info.get_fused_op_id_for_node(node.name)
-            fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
-            if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
-                def update(qc):
-                    qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
-                    qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
-                node.quantization_cfg.update_all(update, remove_duplicates=True)
-            else:
-                node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT)
-                # Remove duplicate candidates. We cannot compare whole candidates since activation configs might not
-                # be identical, but we do want to treat them as such. So we only check duplication by weight configs.
-                uniq_qcs = []
-                for qc in node.candidates_quantization_cfg:
-                    if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq_qcs):
-                        uniq_qcs.append(qc)
-                node.quantization_cfg.candidates_quantization_cfg = uniq_qcs
+        """
+        nodes_to_disable = self.fusing_info.get_inner_fln_nodes()
+        for node in nodes_to_disable:
+            for qc in node.candidates_quantization_cfg:
+                qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
 
     def validate(self):
         """
@@ -908,4 +935,4 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         """
         Wrap networkx functions (that modifies the graph) with our validate decorator.
         """
-        return super().remove_edge(*args, **kwargs)
+        return super().remove_edge(*args, **kwargs)