mct-nightly 2.4.0.20250617.613-py3-none-any.whl → 2.4.0.20250619.621-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
  92. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
  93. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
  94. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  95. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  96. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  97. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  98. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  100. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  101. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  102. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  103. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  104. model_compression_toolkit/gptq/runner.py +1 -7
  105. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  106. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  107. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  108. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  109. model_compression_toolkit/ptq/runner.py +1 -4
  110. model_compression_toolkit/qat/common/qat_config.py +2 -6
  111. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  112. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  113. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  114. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  115. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  116. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  117. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  118. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  119. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  120. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  121. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
  122. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
  123. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
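Nearly every hunk below makes the same mechanical change: the `fw_info: FrameworkInfo` parameter that was threaded through these APIs is removed, and call sites read the equivalent information from the node itself. As a minimal sketch of the pattern (the `linear_node` object and surrounding setup are hypothetical; the accessor names are taken from the hunks below):

# Before (2.4.0.20250617.613): per-type lookups through FrameworkInfo mappings.
kernel = linear_node.get_weights_by_keys(
    fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])
out_axis, in_axis = fw_info.kernel_channels_mapping.get(linear_node.type)

# After (2.4.0.20250619.621): the node exposes its own kernel metadata.
kernel = linear_node.get_weights_by_keys(linear_node.kernel_attr)
out_axis, in_axis = linear_node.channel_axis.output, linear_node.channel_axis.input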
model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py
@@ -32,7 +32,6 @@ from model_compression_toolkit.core.common.substitutions.apply_substitutions imp
 def _collect_and_assign_act_threshold(graph: Graph,
                                       representative_data_gen: Callable,
                                       core_config: CoreConfig,
-                                      fw_info: FrameworkInfo,
                                       fw_impl: FrameworkImplementation):
     """
     Collect statistics after second moment correction and assign new thresholds to activations.
@@ -41,14 +40,12 @@ def _collect_and_assign_act_threshold(graph: Graph,
         representative_data_gen (Callable): Dataset used for calibration.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
-        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
     """
 
     mi = ModelCollector(graph,
                         fw_impl,
-                        fw_info,
-                        core_config.quantization_config)  # Mark points for statistics collection
+                        core_config.quantization_config)  # Mark points for statistics collection
 
     for _data in tqdm(representative_data_gen()):
         mi.infer(_data)
@@ -63,14 +60,12 @@ def _collect_and_assign_act_threshold(graph: Graph,
 
 
 def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
-                                                         fw_info: FrameworkInfo,
                                                          fw_impl: Any):
     """
     Build a framework model from a graph for second moment correction.
 
     Args:
-        graph: Graph to build the from.
-        fw_info: FrameworkInfo object with information about the specific framework's model.
+        graph: Graph to build from.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -79,15 +74,13 @@ def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
     quantized_tg = quantize_graph_weights(graph)
 
     quantized_model, user_info = fw_impl.model_builder(quantized_tg,
-                                                       mode=ModelBuilderMode.FLOAT,
-                                                       fw_info=fw_info)
+                                                       mode=ModelBuilderMode.FLOAT)
     return quantized_model
 
 
 def apply_second_moment_correction_to_graph(graph: Graph,
                                             representative_data_gen: Callable,
                                             core_config: CoreConfig,
-                                            fw_info: FrameworkInfo,
                                             fw_impl: FrameworkImplementation) -> Graph:
     """
     Apply second moment correction on graph.
@@ -96,15 +89,14 @@ def apply_second_moment_correction_to_graph(graph: Graph,
         representative_data_gen (Callable): Dataset used for calibration.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
-        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
         Graph after second moment correction.
     """
-    semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_info, fw_impl)
+    semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_impl)
     fw_impl.apply_second_moment_correction(semi_quantized_model, core_config, representative_data_gen, graph)
     graph = substitute(graph, fw_impl.get_substitutions_after_second_moment_correction(core_config.quantization_config))
-    _collect_and_assign_act_threshold(graph, representative_data_gen, core_config, fw_info, fw_impl)
+    _collect_and_assign_act_threshold(graph, representative_data_gen, core_config, fw_impl)
 
     return graph
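The same trim applies to the collector and builder entry points used in this file; the new call shapes, as they appear in the hunks above (the surrounding objects are assumed to already exist):

mi = ModelCollector(graph, fw_impl, core_config.quantization_config)  # no fw_info argument
quantized_model, user_info = fw_impl.model_builder(quantized_tg, mode=ModelBuilderMode.FLOAT)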
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py
@@ -64,7 +64,6 @@ def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
 
 def compute_activation_bias_correction(graph: Graph,
                                        quant_config: QuantizationConfig,
-                                       fw_info: FrameworkInfo,
                                        fw_impl: FrameworkImplementation,
                                        linear_node: BaseNode,
                                        prev_node: BaseNode,
@@ -76,7 +75,6 @@ def compute_activation_bias_correction(graph: Graph,
     Args:
         graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration.
         quant_config: QuantizationConfig of how the model should be quantized.
-        fw_info: Framework info like lists of nodes their kernel should quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         linear_node: Node to compute the activation bias correction for.
         prev_node: Node to compute the activation error caused by his activation quantization.
@@ -127,19 +125,18 @@ def compute_activation_bias_correction(graph: Graph,
     if normalized_bias < quant_config.activation_bias_correction_threshold:
         return graph
 
-    kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])
+    kernel = linear_node.get_weights_by_keys(linear_node.kernel_attr)
 
     # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output
     # size matching the number of output channels.
     if kernel is not None:
 
         # Get the axes that are not the output channel.
-        output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
         axis_not_output_channel = list(range(len(kernel.shape)))
-        axis_not_output_channel.remove(output_channel_index)
+        axis_not_output_channel.remove(linear_node.channel_axis.output)
 
         # Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters.
-        if output_channel_index == input_channel_index:
+        if linear_node.channel_axis.output == linear_node.channel_axis.input:
             axis_not_output_channel.remove(3)  # 3 is the depth multiplier index.
 
         activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
@@ -150,7 +147,6 @@ def compute_activation_bias_correction(graph: Graph,
 
 def compute_activation_bias_correction_of_graph(graph: Graph,
                                                 quant_config: QuantizationConfig,
-                                                fw_info: FrameworkInfo,
                                                 fw_impl: FrameworkImplementation,
                                                 activation_bias_correction_node_matchers: Callable,
                                                 kernel_size: str) -> Graph:
@@ -160,7 +156,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
     Args:
         graph: Graph with nodes to compute the activation bias correction.
         quant_config: QuantizationConfig of how the model should be quantized.
-        fw_info: Framework info like lists of nodes their kernel should quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
         kernel_size: The framework specific attribute name of the convolution layer's kernel size.
@@ -177,7 +172,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
         if prev_node is not None:
             graph = compute_activation_bias_correction(graph=graph,
                                                        quant_config=quant_config,
-                                                       fw_info=fw_info,
                                                        fw_impl=fw_impl,
                                                        linear_node=n,
                                                        prev_node=prev_node,
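For reference, the correction term computed above sums the kernel over every axis except the output-channel axis, leaving one value per output channel. A small NumPy illustration with a hypothetical TF-style kernel (shapes chosen only for the example):

import numpy as np

kernel = np.random.randn(3, 3, 8, 16)   # hypothetical (kh, kw, cin, cout) layout
mean_diff = np.random.randn(16)         # per-output-channel quantization error

output_channel_axis = 3                 # what linear_node.channel_axis.output would yield here
axis_not_output_channel = [ax for ax in range(kernel.ndim) if ax != output_channel_axis]

# One correction value per output channel, as in compute_activation_bias_correction.
correction = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
assert correction.shape == (16,)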
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
@@ -18,7 +18,6 @@ from typing import Any
 import numpy as np
 
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_weights_attr_by_qc
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -26,7 +25,6 @@ from model_compression_toolkit.logger import Logger
 
 
 def compute_bias_correction_of_graph(graph: Graph,
-                                     fw_info: FrameworkInfo,
                                      fw_impl: FrameworkImplementation) -> Graph:
     """
     For each node in a graph, and for each candidate weights quantization configuration,
@@ -35,7 +33,6 @@ def compute_bias_correction_of_graph(graph: Graph,
     Args:
         graph: Graph with nodes to compute the bias correction for
             each node's weights quantization configuration candidates.
-        fw_info: Framework info like lists of nodes their kernel should quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -46,17 +43,15 @@ def compute_bias_correction_of_graph(graph: Graph,
     for n in graph.nodes:
         # Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
         # name out of all the weights attributes of the node.
-        if fw_info.is_kernel_op(n.type):
-            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
-            if n.is_weights_quantization_enabled(kernel_attr):
+        if n.is_kernel_op:
+            if n.is_weights_quantization_enabled(n.kernel_attr):
                 # Bias correction is not applied to layers with constant inputs.
                 if n.has_positional_weights:
                     for candidate_qc in n.candidates_quantization_cfg:
                         candidate_qc.weights_quantization_cfg.weights_bias_correction = False
                 else:
                     _compute_bias_correction_per_candidate_qc(n,
-                                                              kernel_attr,
-                                                              fw_info,
+                                                              n.kernel_attr,
                                                               graph.get_in_stats_collector(n),
                                                               fw_impl=fw_impl)
     return graph
@@ -64,7 +59,6 @@ def compute_bias_correction_of_graph(graph: Graph,
 
 def _compute_bias_correction_per_candidate_qc(node: BaseNode,
                                               kernel_attr: str,
-                                              fw_info: FrameworkInfo,
                                               node_in_stats_collector: BaseStatsCollector,
                                               fw_impl: FrameworkImplementation):
     """
@@ -74,7 +68,6 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode,
     Args:
         node: Node to compute the bias correction for its different candidates.
         kernel_attr: The name of the kernel attribute of the node.
-        fw_info: Framework info like lists of nodes their kernel should quantized.
         node_in_stats_collector: Statistics collector of the node for the mean per-channel.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py
@@ -32,7 +32,6 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 
 def statistics_correction_runner(transformed_graph: Graph,
                                  core_config: CoreConfig,
-                                 fw_info: FrameworkInfo,
                                  fw_impl: FrameworkImplementation,
                                  tb_w: TensorboardWriter = None, ) -> Graph:
     """
@@ -41,7 +40,6 @@ def statistics_correction_runner(transformed_graph: Graph,
         transformed_graph: Graph to add statistics correction.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
-        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
 
@@ -59,7 +57,6 @@ def statistics_correction_runner(transformed_graph: Graph,
     # Compute bias correction to nodes' config candidates
     ########################################################
     tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
-                                                    fw_info,
                                                     fw_impl)
 
     if tb_w is not None:
@@ -71,7 +68,6 @@ def apply_statistics_correction(transformed_graph: Graph,
 def apply_statistics_correction(transformed_graph: Graph,
                                 representative_data_gen: Callable,
                                 core_config: CoreConfig,
-                                fw_info: FrameworkInfo,
                                 fw_impl: FrameworkImplementation,
                                 tb_w: TensorboardWriter = None, ) -> Graph:
     """
@@ -81,7 +77,6 @@ def apply_statistics_correction(transformed_graph: Graph,
         representative_data_gen (Callable): Dataset used for calibration.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
-        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
 
@@ -94,7 +89,7 @@ def apply_statistics_correction(transformed_graph: Graph,
     #############################################
     if core_config.quantization_config.weights_second_moment_correction:
         transformed_graph = apply_second_moment_correction_to_graph(transformed_graph, representative_data_gen,
-                                                                    core_config, fw_info, fw_impl)
+                                                                    core_config, fw_impl)
 
     #############################################
     # Apply Bias Correction
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
@@ -97,10 +97,9 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
         # This feature disabled for models with weights quantization method of Power of 2
         for qc in source_node.candidates_quantization_cfg:
             # this feature is relevant only for layers with kernel op
-            kernel_attr = graph.fw_info.get_kernel_op_attributes(source_node.type)
-            if kernel_attr is None:
+            if source_node.kernel_attr is None:
                 Logger.error(f"Can't preform BatchNorm reconstruction on a node {source_node.name} without a kernel op.")
-            if (qc.weights_quantization_cfg.get_attr_config(kernel_attr[0]).weights_quantization_method
+            if (qc.weights_quantization_cfg.get_attr_config(source_node.kernel_attr).weights_quantization_method
                     == QuantizationMethod.POWER_OF_TWO):
                 Logger.warning("Second moment statistics correction feature disabled for models with weights "
                                "quantization method of Power of 2")
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
@@ -157,7 +157,7 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
         graph.remove_node(bn_node)
         graph.remove_node(source_node)
 
-        self._calc_weights_quantization_params(conv_bn, weights_scale, graph.fw_info)
+        self._calc_weights_quantization_params(conv_bn, weights_scale)
 
         assert num_nodes_before_substitution - len(graph.nodes) == 1
         assert num_edges_before_substitution - len(graph.edges) == 1
@@ -165,18 +165,15 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
 
     def _calc_weights_quantization_params(self,
                                           conv_bn: BaseNode,
-                                          weights_scale: np.ndarray,
-                                          fw_info):
+                                          weights_scale: np.ndarray):
         """
         Update node weights quantization params.
         Args:
            conv_bn: Convolution node to update the weights quantization params.
            weights_scale: Weight scale factor in which to multiply the conv node's weight.
-           fw_info: FrameworkInfo object with information about the specific framework's model
        """
         # Conv layer is ensured to have a kernel attribute
-        kernel_attr = fw_info.get_kernel_op_attributes(conv_bn.type)[0]
-        conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(kernel_attr)
+        conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(conv_bn.kernel_attr)
         # In case of SYMMETRIC weight quantization method, we update the threshold by weights_scale
         if conv_bn_kernel_cfg.weights_quantization_method == QuantizationMethod.SYMMETRIC:
             original_threshold = conv_bn_kernel_cfg.weights_quantization_params[THRESHOLD]
model_compression_toolkit/core/common/substitutions/scale_equalization.py
@@ -20,8 +20,6 @@ import scipy
 
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
-from model_compression_toolkit.defaultdict import DefaultDict
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 
 
@@ -77,7 +75,6 @@ def fixed_second_moment_after_relu(mu: np.ndarray,
 
 def scale_reshaping(scale: np.ndarray,
                     op2d: common.BaseNode,
-                    kernel_channel_mapping: DefaultDict,
                     kernel_str: str,
                     in_channels: bool = True) -> np.ndarray:
     """
@@ -89,7 +86,6 @@ def scale_reshaping(scale: np.ndarray,
     Args:
         scale: Scale factor to scale the kernel channels by.
         op2d: Node to scale its kernel.
-        kernel_channel_mapping: Mapping from a layer to a tuple of indices of its output/input kernel channels.
         kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
         in_channels: Kernel's index of input channels.
 
@@ -99,12 +95,11 @@ def scale_reshaping(scale: np.ndarray,
 
     op_ndims = op2d.get_weights_by_keys(kernel_str).ndim
     reshape_target = np.ones(op_ndims, dtype=np.int32)
-    reshape_target[kernel_channel_mapping.get(op2d.type)[int(in_channels)]] = -1
+    reshape_target[op2d.channel_axis.input if in_channels else op2d.channel_axis.output] = -1
     return np.reshape(scale, reshape_target)
 
 
-def update_linear_nodes(fw_info: FrameworkInfo,
-                        first_op2d_node: BaseNode,
+def update_linear_nodes(first_op2d_node: BaseNode,
                         second_op2d_node: BaseNode,
                         scale_factor: np.ndarray,
                         kernel_str: str,
@@ -116,7 +111,6 @@ def update_linear_nodes(fw_info: FrameworkInfo,
     The scale factor contain a scale value per-channel.
 
     Args:
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.)
        first_op2d_node: Node to multiply its kernel by the scale factor.
        second_op2d_node: Node to divide its kernel by the scale factor.
@@ -125,15 +119,12 @@ def update_linear_nodes(fw_info: FrameworkInfo,
        kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
 
    """
-
    w2_fixed = second_op2d_node.get_weights_by_keys(kernel_str) / scale_reshaping(scale_factor,
                                                                                  second_op2d_node,
-                                                                                 fw_info.kernel_channels_mapping,
                                                                                  kernel_str)
 
    w1_fixed = first_op2d_node.get_weights_by_keys(kernel_str) * scale_reshaping(scale_factor,
                                                                                 first_op2d_node,
-                                                                                fw_info.kernel_channels_mapping,
                                                                                 kernel_str,
                                                                                 in_channels=False)
 
@@ -168,8 +159,7 @@ def calculate_scale_correction(first_op2d_node: BaseNode) -> tuple:
    return scale_factor
 
 
-def scale_equalization_lnl(fw_info: FrameworkInfo,
-                           first_op2d_node: BaseNode,
+def scale_equalization_lnl(first_op2d_node: BaseNode,
                            second_op2d_node: BaseNode,
                            kernel_str: str,
                            bias_str: str):
@@ -179,7 +169,6 @@ def scale_equalization_lnl(fw_info: FrameworkInfo,
    follows the activation node to get the same expected output without the scaling.
 
    Args:
-       fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
           groups of layers by how they should be quantized, etc.)
       first_op2d_node: Node to multiply its kernel by the scale factor.
       second_op2d_node: Node to divide its kernel by the scale factor.
@@ -189,8 +178,7 @@ def scale_equalization_lnl(fw_info: FrameworkInfo,
    """
    scale_factor = calculate_scale_correction(first_op2d_node)
 
-    update_linear_nodes(fw_info,
-                        first_op2d_node,
+    update_linear_nodes(first_op2d_node,
                         second_op2d_node,
                         scale_factor,
                         kernel_str,
@@ -206,7 +194,6 @@ class BaseScaleEqualization(common.BaseSubstitution):
 
     def __init__(self,
                  quant_config: QuantizationConfig,
-                 fw_info: FrameworkInfo,
                  matcher_instance,
                  kernel_str: str,
                  bias_str: str):
@@ -214,13 +201,11 @@ class BaseScaleEqualization(common.BaseSubstitution):
         Initialize a ScaleEqualization object.
         Args:
             quant_config: QuantizationConfig containing parameters of how the model should be quantized.
-            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
                 groups of layers by how they should be quantized, etc.)
             matcher_instance: Per substitution matcher instance of type WalkMatcher
         """
 
         self.quant_config = quant_config
-        self.fw_info = fw_info
         self.kernel_str = kernel_str
         self.bias_str = bias_str
         super().__init__(matcher_instance=matcher_instance)
@@ -243,8 +228,7 @@ class BaseScaleEqualization(common.BaseSubstitution):
         act_node = nodes_list[1]
         second_op2d_node = nodes_list[-1]
         if first_op2d_node.prior_info.std_output is not None and act_node.is_activation_quantization_enabled():
-            scale_equalization_lnl(self.fw_info,
-                                   first_op2d_node,
+            scale_equalization_lnl(first_op2d_node,
                                    second_op2d_node,
                                    self.kernel_str,
                                    self.bias_str)
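The `scale_reshaping` helper above builds a broadcastable view of a per-channel scale vector: every axis of the reshape target is 1 except the channel axis being scaled. A short NumPy illustration with hypothetical shapes:

import numpy as np

kernel = np.random.randn(3, 3, 8, 16)    # hypothetical (kh, kw, cin, cout) layout
scale = np.random.rand(8)                # one scale value per input channel
input_channel_axis = 2                   # what op2d.channel_axis.input would yield here

reshape_target = np.ones(kernel.ndim, dtype=np.int32)
reshape_target[input_channel_axis] = -1  # -> shape (1, 1, 8, 1)

# Division broadcasts the scale across kh, kw and cout, as in update_linear_nodes.
w_fixed = kernel / np.reshape(scale, reshape_target)
assert w_fixed.shape == kernel.shape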
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
@@ -46,7 +46,6 @@ If the linear node pads the input tensor with zeros, we modify the padded value
 
 def op2d_bias_correction(op2d_node: BaseNode,
                          shift_to_correct: float,
-                         fw_info: FrameworkInfo,
                          bias_str: str,
                          bias_flag_str: str):
     """
@@ -57,7 +56,6 @@ def op2d_bias_correction(op2d_node: BaseNode,
         op2d_node: Node to compute its bias correction term.
         shift_to_correct: Value that was used to shift the output tensor of
             the non-linear node.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
         bias_str:
         bias_flag_str: The framework specific attribute name of the bias flag.
     """
@@ -76,14 +74,13 @@ def op2d_bias_correction(op2d_node: BaseNode,
     # Each node adds a different noise due to the shifting. It depends on the
     # dimensions of the kernel, thus the correction term is a function of
     # the layer type.
-    kernel = op2d_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(op2d_node.type)[0])
+    kernel = op2d_node.get_weights_by_keys(op2d_node.kernel_attr)
     if kernel is not None:
-        output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(op2d_node.type)
         axis_not_output_channel = list(range(len(kernel.shape)))
-        axis_not_output_channel.remove(output_channel_index)
+        axis_not_output_channel.remove(op2d_node.channel_axis.output)
 
         # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters
-        if output_channel_index == input_channel_index:
+        if op2d_node.channel_axis.output == op2d_node.channel_axis.input:
             axis_not_output_channel.remove(3)  # 3 is the depth multiplier index
 
         bias_correction = shift_to_correct * np.sum(kernel, axis=tuple(axis_not_output_channel))
@@ -250,7 +247,6 @@ def shift_negative_function(graph: Graph,
                             core_config: CoreConfig,
                             non_linear_node: BaseNode,
                             op2d_node: BaseNode,
-                            fw_info: FrameworkInfo,
                             create_add_node: Callable,
                             get_padding_values: Callable,
                             create_pad_node: Callable,
@@ -276,8 +272,6 @@ def shift_negative_function(graph: Graph,
         non_linear_node: Non-linear node with negative values to shift.
         op2d_node: Linear node to correct its bias to overcome the expected error due to
             the shifting.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
-            groups of layers by how they should be quantized, etc.)
         create_add_node: Function to create an add node.
         get_padding_values: Function to compute the op2d node's padding values
         create_pad_node: Function to create an pad node.
@@ -299,7 +293,6 @@ def shift_negative_function(graph: Graph,
     # all candidates have same activation config, so taking the first candidate for calculations
     non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg
 
-
     # get the non-linear activation threshold
     activation_threshold = non_linear_node_cfg_candidate.activation_quantization_params.get(THRESHOLD)
 
@@ -390,7 +383,6 @@ def shift_negative_function(graph: Graph,
                                       first_node=non_linear_node)
     op2d_bias_correction(op2d_node,
                          shift_value,
-                         fw_info,
                          bias_str,
                          bias_flag_str)
 
@@ -401,8 +393,7 @@ def shift_negative_function(graph: Graph,
     graph.set_out_stats_collector_to_node(add_node, add_node_stats_collector)
     graph.shift_stats_collector(add_node, np.array(shift_value))
 
-    set_quantization_configs_to_node(fw_info=fw_info,
-                                     node=add_node,
+    set_quantization_configs_to_node(node=add_node,
                                      graph=graph,
                                      quant_config=core_config.quantization_config,
                                      fqc=graph.fqc,
@@ -428,8 +419,7 @@ def shift_negative_function(graph: Graph,
                                 last_node=op2d_node)
 
     # Set quantization configuration to node, even though we do not quantize it:
-    set_quantization_configs_to_node(fw_info=fw_info,
-                                     node=pad_node,
+    set_quantization_configs_to_node(node=pad_node,
                                      graph=graph,
                                      quant_config=core_config.quantization_config,
                                      fqc=graph.fqc,
@@ -472,7 +462,6 @@ def shift_negative_function(graph: Graph,
             candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
 
         candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config,
-                                                                             fw_info,
                                                                              add_node_qco[op_qc_idx])
 
         candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
@@ -573,7 +562,6 @@ def get_next_nodes_to_correct(n: BaseNode,
 
 def apply_shift_negative_correction(graph: Graph,
                                     core_config: CoreConfig,
-                                    fw_info: FrameworkInfo,
                                     snc_node_types: NodeOperationMatcher,
                                     linear_node_types: NodeOperationMatcher,
                                     bypass_node_types: NodeOperationMatcher,
@@ -593,7 +581,6 @@ def apply_shift_negative_correction(graph: Graph,
     Args:
         graph: Graph to apply the substitution on.
         core_config: Quantization configuration to build the substitutions list according to.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         snc_node_types: Types of activation nodes with negative outputs to consider.
         linear_node_types: Types of linear nodes to consider.
@@ -632,7 +619,6 @@ def apply_shift_negative_correction(graph: Graph,
                                         core_config,
                                         n,
                                         linear_node,
-                                        fw_info,
                                         create_add_node,
                                         get_padding_values,
                                         create_pad_node,
model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py
@@ -50,9 +50,7 @@ class BaseVirtualActivationWeightsComposition(BaseSubstitution):
             return graph
 
         # Virtual composed activation-weights node
-        v_node = VirtualActivationWeightsNode(act_node,
-                                              weights_node,
-                                              fw_info=graph.fw_info)
+        v_node = VirtualActivationWeightsNode(act_node, weights_node)
 
         # Update graph
         graph.add_node(v_node)
model_compression_toolkit/core/common/substitutions/weights_activation_split.py
@@ -50,7 +50,7 @@ class BaseWeightsActivationSplit(BaseSubstitution):
             Graph after applying the substitution.
         """
         # The decomposition works on linear nodes, that is, nodes with kernel ops
-        kernel_attr = graph.fw_info.get_kernel_op_attributes(node.type)[0]
+        kernel_attr = node.kernel_attr
         if kernel_attr is None:
             Logger.critical(f"Trying to split node weights and activation, but node "
                             f"{node.name} doesn't have a kernel attribute.")
model_compression_toolkit/core/common/visualization/nn_visualizer.py
@@ -59,22 +59,19 @@ class NNVisualizer:
     def __init__(self,
                  graph_float: Graph,
                  graph_quantized: Graph,
-                 fw_impl: FrameworkImplementation,
-                 fw_info: FrameworkInfo):
+                 fw_impl: FrameworkImplementation):
         """
         Initialize a NNVisualizer object.
         Args:
             graph_float: Float version of the graph.
             graph_quantized: Quantized version of the graph.
             fw_impl: Framework implementation with framework-specific methods implementation.
-            fw_info: Framework info with framework-specific information.
 
         """
 
         self.graph_float = graph_float
         self.graph_quantized = graph_quantized
         self.fw_impl = fw_impl
-        self.fw_info = fw_info
 
         # Get compare points of two graphs.
         self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
@@ -92,13 +89,11 @@ class NNVisualizer:
 
         self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
                                                              mode=ModelBuilderMode.QUANTIZED,
-                                                             append2output=self.compare_points,
-                                                             fw_info=self.fw_info)
+                                                             append2output=self.compare_points)
 
         self.float_model, _ = self.fw_impl.model_builder(self.graph_float,
                                                          mode=ModelBuilderMode.FLOAT,
-                                                         append2output=self.compare_points_float,
-                                                         fw_info=self.fw_info)
+                                                         append2output=self.compare_points_float)
 
     def has_compare_points(self) -> bool:
         """
model_compression_toolkit/core/common/visualization/tensorboard_writer.py
@@ -89,20 +89,18 @@ class TensorboardWriter(object):
     Class to log events to display using Tensorboard such as graphs, histograms, images, etc.
     """
 
-    def __init__(self, dir_path: str, fw_info: FrameworkInfo):
+    def __init__(self, dir_path: str):
         """
         Initialize a TensorboardWriter object.
 
         Args:
             dir_path: Path to save all events to display on Tensorboard.
-            fw_info: FrameworkInfo object (needed for computing nodes' weights memory).
 
         """
         self.dir_path = dir_path
         # we hold EventWriter per tag name, so events can be gathered by tags (like phases during the quantization
         # process).
         self.tag_name_to_event_writer = {}
-        self.fw_info = fw_info
 
     def close(self):
         """
@@ -232,7 +230,7 @@ class TensorboardWriter(object):
            if n.final_weights_quantization_cfg is not None:
                attr.update(n.final_weights_quantization_cfg.__dict__)
            elif n.candidates_quantization_cfg is not None:
-               attr.update(n.get_unified_weights_candidates_dict(self.fw_info))
+               attr.update(n.get_unified_weights_candidates_dict())
            return attr
 
        def __get_node_attr(n: BaseNode) -> Dict[str, Any]:
@@ -296,7 +294,7 @@ class TensorboardWriter(object):
 
            return NodeExecStats(node_name=n.name,
                                 memory=[AllocatorMemoryUsed(
-                                    total_bytes=int(n.get_memory_bytes(self.fw_info))
+                                    total_bytes=int(n.get_memory_bytes())
                                 )])
 
        graph_def = GraphDef()  # GraphDef to add to Tensorboard
@@ -526,13 +524,13 @@ class TensorboardWriter(object):
        er.add_event(event)
        er.flush()
 
-def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
+
+def init_tensorboard_writer() -> TensorboardWriter:
    """
    Create a TensorBoardWriter object initialized with the logger dir path if it was set,
    or None otherwise.
 
    Args:
-       fw_info: FrameworkInfo object.
 
    Returns:
        A TensorBoardWriter object.
@@ -541,7 +539,7 @@ def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
    if Logger.LOG_PATH is not None:
        tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
        Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
-       tb_w = TensorboardWriter(tb_log_dir, fw_info)
+       tb_w = TensorboardWriter(tb_log_dir)
    return tb_w
 
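Callers of the tensorboard utilities drop the argument accordingly; a minimal usage sketch under the new signatures (the surrounding setup is hypothetical):

from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer

tb_w = init_tensorboard_writer()  # returns None unless Logger.LOG_PATH was set
if tb_w is not None:
    tb_w.close()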