mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  92. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  93. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  94. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  95. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  96. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  97. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  98. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  99. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  100. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  101. model_compression_toolkit/gptq/runner.py +1 -7
  102. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  103. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  104. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  105. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  106. model_compression_toolkit/ptq/runner.py +1 -4
  107. model_compression_toolkit/qat/common/qat_config.py +2 -6
  108. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  109. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  110. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  111. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  112. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  113. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  114. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  115. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  116. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  117. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  118. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
  119. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
  120. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
15
15
  import numpy as np
16
16
  from typing import runtime_checkable, Protocol, Callable, Any, List, Tuple
17
17
 
18
- from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig, MpDistanceWeighting
18
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, MpDistanceWeighting
19
19
  from model_compression_toolkit.core.common import Graph, BaseNode
20
20
  from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
21
21
  HessianScoresGranularity
@@ -62,15 +62,12 @@ class DistanceMetricCalculator(MetricCalculator):
62
62
  graph: Graph,
63
63
  mp_config: MixedPrecisionQuantizationConfig,
64
64
  representative_data_gen: Callable,
65
- fw_info: FrameworkInfo,
66
65
  fw_impl: Any,
67
66
  hessian_info_service: HessianInfoService = None):
68
67
  """
69
68
  Args:
70
69
  graph: Graph to search for its MP configuration.
71
70
  mp_config: MP Quantization configuration for how the graph should be quantized.
72
- fw_info: FrameworkInfo object about the specific framework
73
- (e.g., attributes of different layers' weights to quantize).
74
71
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
75
72
  representative_data_gen: Dataset used for getting batches for inference.
76
73
  hessian_info_service: HessianInfoService to fetch Hessian approximation information.
@@ -78,14 +75,13 @@ class DistanceMetricCalculator(MetricCalculator):
78
75
  self.graph = graph
79
76
  self.mp_config = mp_config
80
77
  self.representative_data_gen = representative_data_gen
81
- self.fw_info = fw_info
82
78
  self.fw_impl = fw_impl
83
79
 
84
80
  if self.mp_config.distance_weighting_method == MpDistanceWeighting.HESSIAN:
85
81
  assert hessian_info_service is not None, ('Expected HessianInfoService object to be passed with Hessian '
86
82
  'distance weighting')
87
83
 
88
- self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names(self.fw_info)
84
+ self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names()
89
85
 
90
86
  # Get interest points and output points set for distance measurement and set other helper datasets
91
87
  # We define a separate set of output nodes of the model for the purpose of sensitivity computation.
@@ -396,8 +392,8 @@ class DistanceMetricCalculator(MetricCalculator):
396
392
  """
397
393
 
398
394
  return [n.node for n in graph.get_outputs()
399
- if (graph.fw_info.is_kernel_op(n.node.type) and
400
- n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or
395
+ if (n.node.is_kernel_op and
396
+ n.node.is_weights_quantization_enabled(n.node.kernel_attr)) or
401
397
  n.node.is_activation_quantization_enabled()]
402
398
 
403
399
  @staticmethod
@@ -38,7 +38,6 @@ class SensitivityEvaluation:
38
38
  graph: Graph,
39
39
  mp_config: MixedPrecisionQuantizationConfig,
40
40
  representative_data_gen: Callable,
41
- fw_info: FrameworkInfo,
42
41
  fw_impl: Any,
43
42
  disable_activation_for_metric: bool = False,
44
43
  hessian_info_service: HessianInfoService = None
@@ -46,8 +45,6 @@ class SensitivityEvaluation:
46
45
  """
47
46
  Args:
48
47
  graph: Graph to search for its MP configuration.
49
- fw_info: FrameworkInfo object about the specific framework
50
- (e.g., attributes of different layers' weights to quantize).
51
48
  mp_config: MP Quantization configuration for how the graph should be quantized.
52
49
  representative_data_gen: Dataset used for getting batches for inference.
53
50
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
@@ -57,14 +54,13 @@ class SensitivityEvaluation:
57
54
  """
58
55
  self.mp_config = mp_config
59
56
  self.representative_data_gen = representative_data_gen
60
- self.fw_info = fw_info
61
57
  self.fw_impl = fw_impl
62
58
 
63
59
  if self.mp_config.custom_metric_fn:
64
60
  self.metric_calculator = CustomMetricCalculator(graph, self.mp_config.custom_metric_fn)
65
61
  else:
66
62
  self.metric_calculator = DistanceMetricCalculator(graph, mp_config, representative_data_gen,
67
- fw_info=fw_info, fw_impl=fw_impl,
63
+ fw_impl=fw_impl,
68
64
  hessian_info_service=hessian_info_service)
69
65
 
70
66
  # Build a mixed-precision model which can be configured to use different bitwidth in different layers.
@@ -111,8 +107,7 @@ class SensitivityEvaluation:
111
107
 
112
108
  model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
113
109
  mode=ModelBuilderMode.MIXEDPRECISION,
114
- append2output=outputs,
115
- fw_info=self.fw_info)
110
+ append2output=outputs)
116
111
 
117
112
  # Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
118
113
  for layer in itertools.chain(*conf_node2layers.values()):
@@ -18,7 +18,7 @@ import numpy as np
18
18
  from typing import List, Union, Tuple, Optional
19
19
 
20
20
  from networkx.algorithms.dag import topological_sort
21
- from model_compression_toolkit.core import FrameworkInfo, QuantizationErrorMethod
21
+ from model_compression_toolkit.core import QuantizationErrorMethod
22
22
  from model_compression_toolkit.core import common
23
23
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
24
24
  from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -30,7 +30,6 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
30
30
 
31
31
 
32
32
  def create_stats_collector_for_node(node: common.BaseNode,
33
- fw_info: FrameworkInfo,
34
33
  quant_node_in_fln: bool) -> BaseStatsCollector:
35
34
  """
36
35
  Gets a node and a groups list and create and return a statistics collector for a node
@@ -39,7 +38,7 @@ def create_stats_collector_for_node(node: common.BaseNode,
39
38
 
40
39
  Args:
41
40
  node: Node to create its statistics collector.
42
- fw_info: Information relevant to a specific framework about what is out channel axis (for statistics per-channel).
41
+ quant_node_in_fln: Whether the node should be quantized as part of an FLN.
43
42
 
44
43
  Returns:
45
44
  Statistics collector for statistics collection for the node.
@@ -48,7 +47,7 @@ def create_stats_collector_for_node(node: common.BaseNode,
48
47
  if node.is_activation_quantization_enabled() or quant_node_in_fln:
49
48
  min_output = getattr(node.prior_info, 'min_output', None)
50
49
  max_output = getattr(node.prior_info, 'max_output', None)
51
- stats_collector = common.StatsCollector(out_channel_axis=fw_info.out_channel_axis_mapping.get(node.type),
50
+ stats_collector = common.StatsCollector(out_channel_axis=node.out_channel_axis,
52
51
  init_min_value=min_output,
53
52
  init_max_value=max_output)
54
53
  else:
@@ -58,21 +57,19 @@ def create_stats_collector_for_node(node: common.BaseNode,
58
57
 
59
58
 
60
59
  def create_tensor2node(graph: common.Graph,
61
- node: common.BaseNode,
62
- fw_info: common.FrameworkInfo):
60
+ node: common.BaseNode):
63
61
  """
64
62
  Force statistic collector creation and assignment for a node.
65
63
  Args:
66
64
  graph: Graph of the node (for retrieving the current tensor).
67
65
  node: Node to create a tensor for.
68
- fw_info: Specific framework information (for example, output channels index).
69
66
 
70
67
  """
71
68
  current_sc = graph.get_out_stats_collector(node)
72
69
  is_list_nostat_collectors = isinstance(current_sc, list) and len(
73
70
  [sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
74
71
  if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
75
- stats_collector = common.StatsCollector(fw_info.out_channel_axis_mapping.get(node.type))
72
+ stats_collector = common.StatsCollector(node.out_channel_axis)
76
73
  graph.set_out_stats_collector_to_node(node, stats_collector)
77
74
 
78
75
 
@@ -140,7 +137,6 @@ class ModelCollector:
140
137
 
141
138
  def __init__(self, graph: Graph,
142
139
  fw_impl: FrameworkImplementation,
143
- fw_info: FrameworkInfo,
144
140
  hessian_info_service: HessianInfoService = None,
145
141
  qc: common.QuantizationConfig = common.DEFAULTCONFIG):
146
142
  """
@@ -149,12 +145,10 @@ class ModelCollector:
149
145
  Args:
150
146
  graph: Graph to build a model from it.
151
147
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
152
- fw_info: FrameworkInfo object with a specific framework information.
153
148
  qc: Quantization configuration containing parameters for how the graph should be quantized.
154
149
  """
155
150
 
156
151
  self.fw_impl = fw_impl
157
- self.fw_info = fw_info
158
152
  self.hessian_service = hessian_info_service
159
153
  self.qc = qc
160
154
  self.model_outputs = [out.node for out in graph.get_outputs()]
@@ -162,17 +156,15 @@ class ModelCollector:
162
156
  # Assign statistics collectors to nodes
163
157
  for n in graph.get_topo_sorted_nodes():
164
158
  quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
165
- sc = create_stats_collector_for_node(n, fw_info=fw_info, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
159
+ sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
166
160
  # If we use bias correction, and the node has kernel weights to quantize, we need to make sure
167
161
  # its previous nodes' tensors are consistent with this node.
168
- kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
169
- if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(
170
- kernel_attr):
162
+ if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(
163
+ n.kernel_attr):
171
164
  for ie in graph.incoming_edges(n):
172
165
  input_node = ie.source_node
173
166
  create_tensor2node(graph,
174
- input_node,
175
- fw_info)
167
+ input_node)
176
168
  if sc is not None:
177
169
  graph.set_out_stats_collector_to_node(n, sc)
178
170
 
@@ -205,13 +197,11 @@ class ModelCollector:
205
197
  # TODO: Add integration test for this case
206
198
  append2output = outputs_nodes + [n for n in self.model_outputs if n not in outputs_nodes]
207
199
 
208
-
209
200
  # Build a float model and output all layers' outputs
210
201
  # (that should be collected) as the model's outputs
211
202
  self.model, _ = self.fw_impl.model_builder(graph,
212
203
  mode=ModelBuilderMode.FLOAT,
213
- append2output=append2output,
214
- fw_info=self.fw_info)
204
+ append2output=append2output)
215
205
 
216
206
  def infer(self, inputs_list: List[np.ndarray]):
217
207
  """
@@ -10,17 +10,14 @@ class ModelValidation:
10
10
  """
11
11
 
12
12
  def __init__(self,
13
- model: Any,
14
- fw_info:FrameworkInfo):
13
+ model: Any):
15
14
  """
16
15
  Initialize a ModelValidation object.
17
16
 
18
17
  Args:
19
18
  model: Model to check its validity.
20
- fw_info: Information about the specific framework of the model.
21
19
  """
22
20
  self.model = model
23
- self.fw_info = fw_info
24
21
 
25
22
  @abstractmethod
26
23
  def validate_output_channel_consistency(self):
@@ -22,7 +22,7 @@ from model_compression_toolkit.core.common import Graph
22
22
  from model_compression_toolkit.logger import Logger
23
23
 
24
24
 
25
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
+ from model_compression_toolkit.core.common.framework_info import get_fw_info
26
26
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
27
27
  from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
28
28
  get_activation_quantization_params_fn, get_weights_quantization_params_fn
@@ -64,15 +64,13 @@ class BaseAction(ABC):
64
64
  """
65
65
 
66
66
  @abstractmethod
67
- def apply(self, node: BaseNode, graph, fw_info):
67
+ def apply(self, node: BaseNode, graph):
68
68
  """
69
69
  Apply an action on the node after matching the node with a node filter.
70
70
 
71
71
  Args:
72
72
  node: Node to apply the action on.
73
73
  graph: Graph to apply the action on.
74
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
75
- groups of layers by how they should be quantized, etc.)
76
74
 
77
75
  Returns:
78
76
  Node after action is applied.
@@ -95,15 +93,13 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
95
93
  self.kwargs = kwargs
96
94
  self.attr_name = attr_name
97
95
 
98
- def apply(self, node: BaseNode, graph, fw_info):
96
+ def apply(self, node: BaseNode, graph):
99
97
  """
100
98
  Change the attribute 'attr_name' in weights quantization config candidates with 'attr_value'.
101
99
 
102
100
  Args:
103
101
  node: Node object to change its quant_config.
104
102
  graph: Graph to apply the action on.
105
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
106
- groups of layers by how they should be quantized, etc.)
107
103
  Returns:
108
104
  The node after its weights' quantization config candidates have been modified.
109
105
  """
@@ -128,7 +124,7 @@ class ChangeFinalWeightsQuantConfigAttr(BaseAction):
128
124
  self.kwargs = kwargs
129
125
  self.attr_name = attr_name
130
126
 
131
- def apply(self, node: BaseNode, graph, fw_info):
127
+ def apply(self, node: BaseNode, graph):
132
128
  if node.final_weights_quantization_cfg is not None:
133
129
  for parameter_name, parameter_value in self.kwargs.items():
134
130
  node.final_weights_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value,
@@ -147,17 +143,13 @@ class ChangeCandidatesActivationQuantConfigAttr(BaseAction):
147
143
  """
148
144
  self.kwargs = kwargs
149
145
 
150
- def apply(self, node: BaseNode, graph, fw_info):
146
+ def apply(self, node: BaseNode, graph):
151
147
  """
152
148
  Change the attribute 'attr_name' in activation quantization configuration candidates with 'attr_value'.
153
149
 
154
150
  Args:
155
151
  node: Node object to change its quant_config.
156
152
  graph: Graph to apply the action on.
157
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
158
- groups of layers by how they should be quantized, etc.)
159
- Returns:q
160
- The node after its activation quantization configuration candidates have been modified.
161
153
  """
162
154
  for nqc in node.candidates_quantization_cfg:
163
155
  for parameter_name, parameter_value in self.kwargs.items():
@@ -176,7 +168,7 @@ class ChangeFinalActivationQuantConfigAttr(BaseAction):
176
168
  """
177
169
  self.kwargs = kwargs
178
170
 
179
- def apply(self, node: BaseNode, graph, fw_info):
171
+ def apply(self, node: BaseNode, graph):
180
172
  if node.final_activation_quantization_cfg is not None:
181
173
  for parameter_name, parameter_value in self.kwargs.items():
182
174
  node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
@@ -203,15 +195,13 @@ class ChangeQuantizationParamFunction(BaseAction):
203
195
  self.weights_quantization_params_fn = weights_quantization_params_fn
204
196
  self.attr_name = attr_name
205
197
 
206
- def apply(self, node: BaseNode, graph, fw_info):
198
+ def apply(self, node: BaseNode, graph):
207
199
  """
208
200
  Change the node's weights/activations quantization params function.
209
201
 
210
202
  Args:
211
203
  node: Node object to change its quantization params function.
212
204
  graph: Graph to apply the action on.
213
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
214
- groups of layers by how they should be quantized, etc.)
215
205
 
216
206
  Returns:
217
207
  The node after its quantization params function has been modified.
@@ -240,15 +230,13 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
240
230
 
241
231
  self.activation_quantization_method = activation_quantization_method
242
232
 
243
- def apply(self, node: BaseNode, graph, fw_info):
233
+ def apply(self, node: BaseNode, graph):
244
234
  """
245
235
  Change the node's activations quantization function.
246
236
 
247
237
  Args:
248
238
  node: Node object to change its threshold selection function.
249
239
  graph: Graph to apply the action on.
250
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
251
- groups of layers by how they should be quantized, etc.)
252
240
 
253
241
  Returns:
254
242
  The node after its quantization function has been modified.
@@ -262,7 +250,7 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
262
250
  node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
263
251
  activation_quantization_params_fn)
264
252
 
265
- activation_quantization_fn = fw_info.activation_quantizer_mapping.get(self.activation_quantization_method)
253
+ activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(self.activation_quantization_method)
266
254
 
267
255
  node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
268
256
  node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
@@ -282,18 +270,14 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
282
270
  """
283
271
  self.activation_quantization_method = activation_quantization_method
284
272
 
285
- def apply(self, node: BaseNode, graph, fw_info):
273
+ def apply(self, node: BaseNode, graph):
286
274
  """
287
275
  Change the node's activations quantization function.
288
276
 
289
277
  Args:
290
278
  node: Node object to change its threshold selection function.
291
279
  graph: Graph to apply the action on.
292
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
293
- groups of layers by how they should be quantized, etc.)
294
280
 
295
- Returns:
296
- The node after its quantization function has been modified.
297
281
  """
298
282
  if self.activation_quantization_method is not None:
299
283
  for qc in node.candidates_quantization_cfg:
@@ -301,7 +285,7 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
301
285
  self.activation_quantization_method)
302
286
 
303
287
  qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
304
- activation_quantization_fn = fw_info.activation_quantizer_mapping.get(
288
+ activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(
305
289
  self.activation_quantization_method)
306
290
 
307
291
  if activation_quantization_fn is None:
@@ -328,18 +312,14 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
328
312
  self.weights_quantization_method = weights_quantization_method
329
313
  self.attr_name = attr_name
330
314
 
331
- def apply(self, node: BaseNode, graph, fw_info):
315
+ def apply(self, node: BaseNode, graph):
332
316
  """
333
317
  Change the node's weights quantization function.
334
318
 
335
319
  Args:
336
320
  node: Node object to change its threshold selection function.
337
321
  graph: Graph to apply the action on.
338
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
339
- groups of layers by how they should be quantized, etc.)
340
322
 
341
- Returns:
342
- The node after its quantization function has been modified.
343
323
  """
344
324
 
345
325
  if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
@@ -376,15 +356,13 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
376
356
  self.weights_quantization_method = weights_quantization_method
377
357
  self.attr_name = attr_name
378
358
 
379
- def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo):
359
+ def apply(self, node: BaseNode, graph: Graph):
380
360
  """
381
361
  Change the node's weights quantization function.
382
362
 
383
363
  Args:
384
364
  node: Node object to change its threshold selection function.
385
365
  graph: Graph to apply the action on.
386
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
387
- groups of layers by how they should be quantized, etc.)
388
366
 
389
367
  Returns:
390
368
  The node after its quantization function has been modified.
@@ -422,15 +400,13 @@ class ReplaceLayer(BaseAction):
422
400
  self.layer_type = layer_type
423
401
  self.get_params_and_weights_fn = get_params_and_weights_fn
424
402
 
425
- def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo):
403
+ def apply(self, node: BaseNode, graph: Graph):
426
404
  """
427
405
  Replacing node's layer type and configurations
428
406
 
429
407
  Args:
430
408
  node: Node object to replace or modify
431
409
  graph: Graph to apply the action on.
432
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
433
- groups of layers by how they should be quantized, etc.)
434
410
 
435
411
  Returns:
436
412
  The node after its layer functionality has been modified.
@@ -14,20 +14,17 @@
14
14
  # ==============================================================================
15
15
  from typing import List
16
16
 
17
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
18
17
  from model_compression_toolkit.core.common.graph.base_graph import Graph
19
18
  from model_compression_toolkit.core.common.network_editors import EditRule
20
19
 
21
20
 
22
21
  def edit_network_graph(graph: Graph,
23
- fw_info: FrameworkInfo,
24
22
  network_editor: List[EditRule]):
25
23
  """
26
24
  Apply a list of edit rules on a graph.
27
25
 
28
26
  Args:
29
27
  graph: The graph to edit.
30
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
31
28
  groups of layers by how they should be quantized, etc.)
32
29
  network_editor: List of edit rules to apply to the graph.
33
30
 
@@ -38,5 +35,5 @@ def edit_network_graph(graph: Graph,
38
35
  for edit_rule in network_editor:
39
36
  filtered_nodes = graph.filter(edit_rule.filter)
40
37
  for node in filtered_nodes:
41
- edit_rule.action.apply(node, graph, fw_info)
38
+ edit_rule.action.apply(node, graph)
42
39
  # return graph
@@ -26,18 +26,14 @@ class ChannelGrouping:
26
26
  based on their importance scores and SIMD group sizes.
27
27
  """
28
28
 
29
- def __init__(self,
30
- prunable_nodes: List[BaseNode],
31
- fw_info: FrameworkInfo):
29
+ def __init__(self, prunable_nodes: List[BaseNode]):
32
30
  """
33
31
  Initializes the ChannelGrouping with necessary information.
34
32
 
35
33
  Args:
36
34
  prunable_nodes: List of nodes that can be pruned.
37
- fw_info: Framework-specific information and utilities.
38
35
  """
39
36
  self.prunable_nodes = prunable_nodes
40
- self.fw_info = fw_info
41
37
  # Store for each node a list of numpy arrays. Each numpy array represents the
42
38
  # indices of the channels in an SIMD group.
43
39
  self._simd_groups_indices = {}
@@ -38,7 +38,6 @@ class GreedyMaskCalculator:
38
38
  """
39
39
  def __init__(self,
40
40
  prunable_nodes: List[BaseNode],
41
- fw_info: FrameworkInfo,
42
41
  simd_groups_scores: Dict[BaseNode, np.ndarray],
43
42
  target_resource_utilization: ResourceUtilization,
44
43
  graph: Graph,
@@ -48,7 +47,6 @@ class GreedyMaskCalculator:
48
47
  """
49
48
  Args:
50
49
  prunable_nodes (List[BaseNode]): Nodes that are eligible for pruning.
51
- fw_info (FrameworkInfo): Framework-specific information and utilities.
52
50
  simd_groups_scores (Dict[BaseNode, np.ndarray]): Importance scores for each SIMG group in a prunable node.
53
51
  target_resource_utilization (ResourceUtilization): The target resource utilization to achieve.
54
52
  graph (Graph): The computational graph of the model.
@@ -57,7 +55,6 @@ class GreedyMaskCalculator:
57
55
  simd_groups_indices (Dict[BaseNode, List[List[int]]]): Indices of SIMD groups in each node.
58
56
  """
59
57
  self.prunable_nodes = prunable_nodes
60
- self.fw_info = fw_info
61
58
  self.target_resource_utilization = target_resource_utilization
62
59
  self.graph = graph
63
60
  self.fw_impl = fw_impl
@@ -67,14 +64,11 @@ class GreedyMaskCalculator:
67
64
  self.simd_groups_scores = simd_groups_scores
68
65
 
69
66
  self.oc_pruning_mask = PerSIMDGroupMask(prunable_nodes=prunable_nodes,
70
- fw_info=fw_info,
71
67
  simd_groups_indices=simd_groups_indices)
72
68
 
73
69
  self.memory_calculator = MemoryCalculator(graph=graph,
74
- fw_info=fw_info,
75
70
  fw_impl=fw_impl)
76
71
 
77
-
78
72
  def get_mask(self) -> Dict[BaseNode, np.ndarray]:
79
73
  """
80
74
  Retrieves the current pruning mask for each prunable node.
@@ -38,8 +38,7 @@ class LFHImportanceMetric(BaseImportanceMetric):
38
38
  graph: Graph,
39
39
  representative_data_gen: Callable,
40
40
  fw_impl: PruningFrameworkImplementation,
41
- pruning_config: PruningConfig,
42
- fw_info: FrameworkInfo):
41
+ pruning_config: PruningConfig):
43
42
  """
44
43
  Initialize the LFHImportanceMetric instance.
45
44
 
@@ -48,13 +47,11 @@ class LFHImportanceMetric(BaseImportanceMetric):
48
47
  representative_data_gen (Callable): Function to generate representative data.
49
48
  fw_impl (PruningFrameworkImplementation): Implementation of pruning for the framework.
50
49
  pruning_config (PruningConfig): Configuration for pruning.
51
- fw_info (FrameworkInfo): Framework-specific information.
52
50
  """
53
51
  self.float_graph = graph
54
52
  self.representative_data_gen = representative_data_gen
55
53
  self.fw_impl = fw_impl
56
54
  self.pruning_config = pruning_config
57
- self.fw_info = fw_info
58
55
 
59
56
  # Initialize internal dictionaries for storing intermediate computations.
60
57
  self._entry_node_to_hessian_score = {}
@@ -158,8 +155,7 @@ class LFHImportanceMetric(BaseImportanceMetric):
158
155
  Dict[BaseNode, List[np.ndarray]]: Dictionary of entry nodes mapped to their SIMD group indices.
159
156
  """
160
157
  # Initialize channel grouping utility.
161
- channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()),
162
- fw_info=self.fw_info)
158
+ channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()))
163
159
 
164
160
  channel_grouping.group_scores_by_simd_groups(entry_node_to_score)
165
161
  grouped_indices = channel_grouping.simd_groups_indices
@@ -249,20 +245,14 @@ class LFHImportanceMetric(BaseImportanceMetric):
249
245
  Returns:
250
246
  tuple: A tuple containing the kernel attribute, the number of output channels, and the axis of the output channels.
251
247
  """
252
- kernel_attr = self.fw_info.get_kernel_op_attributes(entry_node.type)
253
- # Ensure only one kernel attribute exists for the given node.
254
- if len(kernel_attr) != 1:
255
- Logger.critical(f"Expected a single attribute but found multiple attributes ({len(kernel_attr)}) for node {entry_node}.")
256
- kernel_attr = kernel_attr[0]
257
-
258
248
  # Retrieve and validate the axis index for the output channels.
259
- oc_axis, _ = self.fw_info.kernel_channels_mapping.get(entry_node.type)
249
+ oc_axis = entry_node.channel_axis.output
260
250
  if oc_axis is None or int(oc_axis) != oc_axis:
261
251
  Logger.critical(f"Invalid output channel axis type for node {entry_node}: expected integer but got {oc_axis}.")
262
252
 
263
253
  # Get the number of output channels based on the kernel attribute and axis.
264
- num_oc = entry_node.get_weights_by_keys(kernel_attr[0]).shape[oc_axis]
265
- return kernel_attr, num_oc, oc_axis
254
+ num_oc = entry_node.get_weights_by_keys(entry_node.kernel_attr).shape[oc_axis]
255
+ return entry_node.kernel_attr, num_oc, oc_axis
266
256
 
267
257
  def _concatenate_tensors_by_indices(self,
268
258
  channels: List[np.ndarray],
@@ -35,9 +35,8 @@ class MaskIndicator(Enum):
35
35
  REMAINED = 1
36
36
 
37
37
 
38
-
39
38
  class PerChannelMask:
40
- def __init__(self, prunable_nodes: List[BaseNode], fw_info: FrameworkInfo):
39
+ def __init__(self, prunable_nodes: List[BaseNode]):
41
40
  """
42
41
  Initializes the PerChannelMask with prunable nodes and framework information.
43
42
  This class is responsible for maintaining and updating the pruning masks for each
@@ -46,10 +45,8 @@ class PerChannelMask:
46
45
 
47
46
  Args:
48
47
  prunable_nodes: List of nodes in the model that are subject to pruning.
49
- fw_info: Framework-specific information required for pruning operations.
50
48
  """
51
49
  self.prunable_nodes = prunable_nodes
52
- self.fw_info = fw_info
53
50
  self._mask = None # Initialize the mask dictionary
54
51
  self._init_masks() # Call to initialize masks for each prunable node
55
52
 
@@ -106,8 +103,7 @@ class PerChannelMask:
106
103
  Returns:
107
104
  int: Number of output channels for the node.
108
105
  """
109
- kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
110
- oc_axis = self.fw_info.kernel_channels_mapping.get(node.type)[0]
111
- num_oc = node.get_weights_by_keys(kernel_attr).shape[oc_axis]
106
+ oc_axis = node.channel_axis.output
107
+ num_oc = node.get_weights_by_keys(node.kernel_attr).shape[oc_axis]
112
108
  return num_oc
113
109
 
@@ -24,10 +24,10 @@ from model_compression_toolkit.core.common.pruning.memory_calculator import Memo
24
24
  from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
25
25
  from model_compression_toolkit.logger import Logger
26
26
 
27
+
27
28
  class PerSIMDGroupMask:
28
29
  def __init__(self,
29
30
  prunable_nodes: List[BaseNode],
30
- fw_info: FrameworkInfo,
31
31
  simd_groups_indices: Dict[BaseNode, List[List[int]]]):
32
32
  """
33
33
  Initializes a mask calculator for SIMD groups in prunable nodes.
@@ -35,13 +35,11 @@ class PerSIMDGroupMask:
35
35
 
36
36
  Args:
37
37
  prunable_nodes: List of nodes that can be pruned.
38
- fw_info: Framework-specific information.
39
38
  simd_groups_indices: A dictionary mapping each node to its SIMD groups' indices.
40
39
  """
41
40
  # Initialize the per-channel mask
42
- self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes, fw_info=fw_info)
41
+ self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes)
43
42
  self.prunable_nodes = prunable_nodes
44
- self.fw_info = fw_info
45
43
  self.simd_groups_indices = simd_groups_indices
46
44
  self._mask_simd = None # Initialize the SIMD group mask dictionary
47
45
  self._init_masks() # Initialize masks for each prunable node