mct-nightly 2.4.0.20250705.556__py3-none-any.whl → 2.4.0.20250707.643__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/RECORD +36 -38
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/collectors/base_collector.py +4 -1
  5. model_compression_toolkit/core/common/collectors/mean_collector.py +7 -4
  6. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +7 -4
  7. model_compression_toolkit/core/common/model_collector.py +11 -0
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +1 -1
  9. model_compression_toolkit/core/common/quantization/node_quantization_config.py +22 -87
  10. model_compression_toolkit/core/common/quantization/quantization_config.py +0 -1
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +23 -17
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +26 -48
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +12 -7
  14. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -14
  15. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -1
  16. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -13
  17. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +3 -3
  18. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +5 -7
  19. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +7 -5
  20. model_compression_toolkit/core/graph_prep_runner.py +1 -11
  21. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  22. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +21 -11
  23. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  24. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +8 -0
  25. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  26. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +9 -1
  27. model_compression_toolkit/core/quantization_prep_runner.py +2 -2
  28. model_compression_toolkit/gptq/keras/quantization_facade.py +0 -3
  29. model_compression_toolkit/ptq/keras/quantization_facade.py +0 -3
  30. model_compression_toolkit/qat/keras/quantization_facade.py +0 -3
  31. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +0 -2
  32. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +0 -6
  33. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +2 -4
  34. model_compression_toolkit/core/common/model_validation.py +0 -41
  35. model_compression_toolkit/core/keras/keras_model_validation.py +0 -37
  36. {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/WHEEL +0 -0
  37. {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/licenses/LICENSE.md +0 -0
  38. {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/top_level.txt +0 -0
--- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py
+++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py
@@ -18,21 +18,25 @@ from typing import Dict, Union, Optional, Tuple, Callable
 from mct_quantizers import QuantizationMethod

 import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
+from model_compression_toolkit.constants import MIN_THRESHOLD
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
 from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
 from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod, \
+    QuantizationConfig


-def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
+def compute_activation_qparams(quant_cfg: QuantizationConfig,
+                               node_activation_quant_cfg: NodeActivationQuantizationConfig,
                                node_prior_info: NodePriorInfo,
                                out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
     """
     Compute the activations params for a given node in a graph according to a params function.

     Args:
-        activation_quant_cfg: node's activation quantization configuration.
+        quant_cfg: quantization config.
+        node_activation_quant_cfg: node's activation quantization configuration.
         node_prior_info: Prior info collected for the node that is being quantized.
         out_stats_container: Tensor containing output statistics of the node.

@@ -40,41 +44,43 @@ def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationC
         The computed activation quantization params.
     """
     activation_quantization_params_fn = _get_activation_quantization_params_fn(
-        activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
+        node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())

     # Extract and filter histogram data from the statistics container.
-    bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container)
+    bins_values, bins_counts = _get_histogram_data(out_stats_container,
+                                                   activation_error_method=quant_cfg.activation_error_method,
+                                                   z_threshold=quant_cfg.z_threshold)

     # Retrieve the minimum and maximum values from the statistics container.
     min_value, max_value = out_stats_container.get_min_max_values()

     # Determine if the activations should be considered signed.
-    signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
+    signed = _determine_signedness(node_activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)

     # Compute and return the activation quantization parameters.
     return activation_quantization_params_fn(
         bins_values,
         bins_counts,
-        activation_quant_cfg.l_p_value,
-        activation_quant_cfg.activation_n_bits,
+        quant_cfg.l_p_value,
+        node_activation_quant_cfg.activation_n_bits,
         min_value,
         max_value,
-        min_threshold=activation_quant_cfg.min_threshold,
-        quant_error_method=activation_quant_cfg.activation_error_method,
+        min_threshold=MIN_THRESHOLD,
+        quant_error_method=quant_cfg.activation_error_method,
         is_signed=signed
     )


-def _get_histogram_data(
-        activation_quant_cfg: NodeActivationQuantizationConfig,
-        out_stats_container: BaseStatsCollector
-) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+def _get_histogram_data(out_stats_container: BaseStatsCollector,
+                        activation_error_method: QuantizationErrorMethod,
+                        z_threshold: float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
     """
     Extract and filter the histogram data from the statistics container.

     Args:
-        activation_quant_cfg: Node's activation quantization configuration.
         out_stats_container: Statistics container with histogram data.
+        activation_error_method: activation quantization error method.
+        z_threshold: z threshold for z-score filtering.

     Returns:
         A tuple containing the filtered bins_values and bins_counts.
@@ -83,12 +89,12 @@ def _get_histogram_data(
     # If the statistics container collected the histogram, we start by filtering outliers using z threshold
     # filtering, and then computing the threshold based on the filtered histogram.
     if out_stats_container.require_collection():
-        if activation_quant_cfg.activation_error_method == QuantizationErrorMethod.HMSE:
+        if activation_error_method == QuantizationErrorMethod.HMSE:
            bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
        else:
            bins_values, bins_counts = out_stats_container.hc.get_histogram()
        bins_counts = qpg.z_score_filter(
-            activation_quant_cfg.z_threshold,
+            z_threshold,
            bins_values,
            bins_counts
        )
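
Note: after this change, per-node activation configs no longer carry `l_p_value`, `z_threshold`, or `min_threshold`; those come from the shared `QuantizationConfig` and the package-level `MIN_THRESHOLD` constant. A runnable, shape-only sketch of the new parameter split — all values and the `SimpleNamespace` stand-ins are illustrative, not MCT's actual objects:

    # Shape-only sketch: which knob now lives where. Values are illustrative.
    from types import SimpleNamespace

    quant_cfg = SimpleNamespace(            # shared, graph-wide (was per-node before this diff)
        activation_error_method='MSE',
        l_p_value=2,
        z_threshold=16.0)
    node_activation_quant_cfg = SimpleNamespace(   # still per-node
        activation_quantization_method='POWER_OF_TWO',
        activation_n_bits=8)
    MIN_THRESHOLD = 2 ** -28                # package constant (illustrative tiny value)

    print(quant_cfg.l_p_value, node_activation_quant_cfg.activation_n_bits, MIN_THRESHOLD)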
--- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py
+++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py
@@ -18,7 +18,7 @@ from tqdm import tqdm
 from typing import List, Callable, Generator

 from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
-from model_compression_toolkit.core import QuantizationErrorMethod
+from model_compression_toolkit.core import QuantizationErrorMethod, QuantizationConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -31,29 +31,8 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
 from model_compression_toolkit.logger import Logger


-def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[BaseNode]:
-    """
-    Collects nodes that are compatiable for parameters selection search using HMSE,
-    that is, have a kernel attribute that is configured for HMSE error method.
-
-    Args:
-        nodes_list: A list of nodes to search quantization parameters for.
-        graph: Graph to compute its nodes' quantization parameters..
-
-    Returns: A (possibly empty) list of nodes.
-
-    """
-    hmse_nodes = []
-    for n in nodes_list:
-        if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and \
-                all([c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_error_method ==
-                     QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
-            hmse_nodes.append(n)
-
-    return hmse_nodes
-
-
 def calculate_quantization_params(graph: Graph,
+                                  quant_cfg: QuantizationConfig,
                                   fw_impl: FrameworkImplementation,
                                   repr_data_gen_fn: Callable[[], Generator],
                                   nodes: List[BaseNode] = None,
@@ -68,6 +47,7 @@ def calculate_quantization_params(graph: Graph,

     Args:
         graph: Graph to compute its nodes' thresholds.
+        quant_cfg: quantization config.
         fw_impl: FrameworkImplementation object.
         repr_data_gen_fn: callable returning representative dataset generator.
         nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
@@ -85,15 +65,16 @@ def calculate_quantization_params(graph: Graph,
     # Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
     # and computing required Hessian information to be used for HMSE parameters selection.
     # The Hessian scores are computed and stored in the hessian_info_service object.
-    nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
-    if len(nodes_for_hmse) > 0:
-        dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
-        request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
-                                       granularity=HessianScoresGranularity.PER_ELEMENT,
-                                       data_loader=dataloader,
-                                       n_samples=num_hessian_samples,
-                                       target_nodes=nodes_for_hmse)
-        hessian_info_service.fetch_hessian(request)
+    if quant_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
+        nodes_for_hmse = [n for n in nodes_list if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr)]
+        if nodes_for_hmse:
+            dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
+            request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
+                                           granularity=HessianScoresGranularity.PER_ELEMENT,
+                                           data_loader=dataloader,
+                                           n_samples=num_hessian_samples,
+                                           target_nodes=nodes_for_hmse)
+            hessian_info_service.fetch_hessian(request)

     for n in tqdm(nodes_list, "Calculating quantization parameters"):  # iterate only nodes that we should compute their thresholds
         for candidate_qc in n.candidates_quantization_cfg:
@@ -101,28 +82,24 @@ def calculate_quantization_params(graph: Graph,
             if n.is_weights_quantization_enabled(attr):
                 # If the node's weights attribute should be quantized, we compute its quantization parameters
                 attr_cfg = candidate_qc.weights_quantization_cfg.get_attr_config(attr)
-                channels_axis = attr_cfg.weights_channels_axis
-                if channels_axis is not None:
-                    output_channels_axis = channels_axis[0]
-                else:
-                    output_channels_axis = None
-
-                mod_attr_cfg = attr_cfg
+                output_channels_axis = attr_cfg.weights_channels_axis.output

-                if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
+                weights_error_method = quant_cfg.weights_error_method
+                if weights_error_method == QuantizationErrorMethod.HMSE:
                     # Although we collected nodes for HMSE before running the loop, we keep this verification to
                     # notify the user in case of HMSE configured for node that is not compatible for this method
                     if n.kernel_attr is None or n.kernel_attr not in attr:
                         Logger.warning(f"The HMSE error method for parameters selection is only supported for "
                                        f"kernel weights attributes. Running parameters selection for attribute "
                                        f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
-                        mod_attr_cfg = copy.deepcopy(attr_cfg)
-                        mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
+                        weights_error_method = QuantizationErrorMethod.MSE

-                min_threshold = candidate_qc.weights_quantization_cfg.min_threshold
                 weights_params, output_channels_axis = compute_weights_qparams(n.get_weights_by_keys(attr),
-                                                                               mod_attr_cfg, output_channels_axis,
-                                                                               min_threshold=min_threshold, node=n,
+                                                                               attr_cfg,
+                                                                               weights_error_method,
+                                                                               quant_cfg.l_p_value,
+                                                                               output_channels_axis,
+                                                                               node=n,
                                                                                hessian_info_service=hessian_info_service,
                                                                                num_hessian_samples=num_hessian_samples)
                 attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
@@ -130,8 +107,9 @@ def calculate_quantization_params(graph: Graph,

         if n.is_activation_quantization_enabled():
             # If node's activations should be quantized as well, we compute its activation quantization parameters
-            activation_params = compute_activation_qparams(
-                activation_quant_cfg=candidate_qc.activation_quantization_cfg, node_prior_info=n.prior_info,
-                out_stats_container=graph.get_out_stats_collector(n))
+            activation_params = compute_activation_qparams(quant_cfg=quant_cfg,
+                                                           node_activation_quant_cfg=candidate_qc.activation_quantization_cfg,
+                                                           node_prior_info=n.prior_info,
+                                                           out_stats_container=graph.get_out_stats_collector(n))
             # Create a NodeQuantizationConfig containing all quantization params and attach it to the node
             candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
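
Note: the removed `_collect_nodes_for_hmse` inspected every candidate's per-attribute `weights_error_method`; with the error method now global, the same gating reduces to one config check plus a list comprehension. A self-contained sketch of the equivalent logic, using a simplified stand-in `Node` rather than MCT's `BaseNode`:

    from dataclasses import dataclass
    from enum import Enum
    from typing import List, Optional

    class QuantizationErrorMethod(Enum):  # simplified stand-in for the MCT enum
        MSE = 0
        HMSE = 1

    @dataclass
    class Node:  # hypothetical stand-in for BaseNode
        name: str
        kernel_attr: Optional[str]
        weights_quant_enabled: bool

        def is_weights_quantization_enabled(self, attr: str) -> bool:
            return self.weights_quant_enabled

    def hmse_target_nodes(nodes: List[Node],
                          weights_error_method: QuantizationErrorMethod) -> List[Node]:
        # Hessian scores are needed only when HMSE is the global weights error method,
        # and only for kernel-bearing nodes with weights quantization enabled.
        if weights_error_method != QuantizationErrorMethod.HMSE:
            return []
        return [n for n in nodes if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr)]

    nodes = [Node('conv1', 'kernel', True), Node('relu1', None, False)]
    assert [n.name for n in hmse_target_nodes(nodes, QuantizationErrorMethod.HMSE)] == ['conv1']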
--- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py
+++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py
@@ -18,7 +18,8 @@ from typing import Dict, Any, Tuple, Callable, TYPE_CHECKING
 import numpy as np
 from mct_quantizers import QuantizationMethod

-from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
+from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES, MIN_THRESHOLD
+from model_compression_toolkit.core import QuantizationErrorMethod
 from model_compression_toolkit.core.common.hessian import HessianInfoService
 from model_compression_toolkit.core.common.quantization.quantization_params_generation import \
     power_of_two_selection_tensor, lut_kmeans_tensor, symmetric_selection_tensor, uniform_selection_tensor
@@ -28,10 +29,12 @@ if TYPE_CHECKING:
     from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig


-def compute_weights_qparams(weights_attr_values: np.ndarray,
+def compute_weights_qparams(weights_attr_data: np.ndarray,
                             attr_quant_config: 'WeightsAttrQuantizationConfig',
+                            weights_error_method: QuantizationErrorMethod,
+                            l_p_value: int,
                             output_channels_axis: int,
-                            min_threshold: float,
+                            min_threshold: float = MIN_THRESHOLD,
                             node=None,
                             hessian_info_service: HessianInfoService = None,
                             num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
@@ -40,8 +43,10 @@ def compute_weights_qparams(weights_attr_values: np.ndarray,
     instance.

     Args:
-        weights_attr_values: Weights attribute parameter to compute the quantization thresholds for.
+        weights_attr_data: Weights attribute parameter to compute the quantization thresholds for.
         attr_quant_config: A specific weights attribute quantization configuration to get its params.
+        weights_error_method: quantization error method.
+        l_p_value: p-norm to use for the Lp-norm distance.
         output_channels_axis: Index of the kernel output channels dimension.
         min_threshold: Minimal threshold to use if threshold is too small.
         node: The node for which the quantization error is computed (used only with HMSE error method).
@@ -54,13 +59,13 @@ def compute_weights_qparams(weights_attr_values: np.ndarray,
     """
     params_fn = _get_weights_quantization_params_fn(attr_quant_config.weights_quantization_method)
     weights_params, output_channels_axis = params_fn(
-        weights_attr_values,
-        p=attr_quant_config.l_p_value,
+        weights_attr_data,
+        p=l_p_value,
         n_bits=attr_quant_config.weights_n_bits,
         per_channel=attr_quant_config.weights_per_channel_threshold,
         channel_axis=output_channels_axis,
         min_threshold=min_threshold,
-        quant_error_method=attr_quant_config.weights_error_method,
+        quant_error_method=weights_error_method,
         node=node,
         hessian_info_service=hessian_info_service,
         num_hessian_samples=num_hessian_samples)
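
Note: `compute_weights_qparams` now takes `weights_error_method` and `l_p_value` explicitly (sourced from the global config) instead of reading them off `attr_quant_config`, and `min_threshold` defaults to `MIN_THRESHOLD`. A runnable, shape-only sketch of the new signature; the stub body and all default values are illustrative stand-ins for MCT's real selection logic:

    import numpy as np

    def compute_weights_qparams(weights_attr_data, attr_quant_config, weights_error_method,
                                l_p_value, output_channels_axis, min_threshold=2 ** -28,
                                node=None, hessian_info_service=None, num_hessian_samples=16):
        # Stub: the real function dispatches to a params-selection routine; here we
        # just return one illustrative threshold per output channel.
        n_channels = weights_attr_data.shape[output_channels_axis]
        return {'threshold': np.full(n_channels, max(min_threshold, 1.0))}, output_channels_axis

    w = np.random.randn(3, 3, 8, 16)  # e.g. a channels-last conv kernel
    params, axis = compute_weights_qparams(w, attr_quant_config=None,
                                           weights_error_method='MSE', l_p_value=2,
                                           output_channels_axis=-1)
    assert params['threshold'].shape == (16,)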
--- a/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py
+++ b/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py
@@ -14,8 +14,6 @@
 # ==============================================================================
 import copy

-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
-from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
@@ -23,7 +21,6 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s


 def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
-                                   core_config: CoreConfig,
                                    fw_impl: FrameworkImplementation) -> Graph:
     """
     Get a graph, where each node has a final weights quantization configuration (with a bias
@@ -31,7 +28,6 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,

     Args:
         graph_to_apply_bias_correction: Graph to apply bias correction to.
-        core_config: CoreConfig containing parameters of how the model should be quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.

     Returns:
@@ -40,20 +36,14 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,

     graph = copy.deepcopy(graph_to_apply_bias_correction)
     for n in graph.nodes:
-        # bias correction is only relevant for nodes with kernel op
-        if core_config.quantization_config.weights_bias_correction and n.kernel_attr is not None and \
-                n.is_weights_quantization_enabled(n.kernel_attr) and \
-                not n.final_weights_quantization_cfg.weights_second_moment_correction:
-            # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
-            # a bias correction term was calculated during model preparation, and is used now in the node's bias term.
-            if n.final_weights_quantization_cfg.weights_bias_correction:
-                _apply_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
+        if (n.final_weights_quantization_cfg and n.final_weights_quantization_cfg.bias_corrected is not None and
+                not n.final_weights_quantization_cfg.weights_second_moment_correction):
+            _apply_bias_correction_to_node(n, fw_impl)
     return graph


 def _apply_bias_correction_to_node(node: BaseNode,
-                                   fw_impl: FrameworkImplementation,
-                                   qc: QuantizationConfig):
+                                   fw_impl: FrameworkImplementation):
     """
     Set new bias to node using the bias-correction term that is stored in the
     final weights quantization configuration.
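
Note: apply-time eligibility is now derived from computed state (`bias_corrected is not None`) rather than re-checking config flags; the `weights_bias_correction` flag check moves to the callers in statistics_correction.py (below). A tiny stand-in sketch of the new predicate — the config class is a simplified placeholder, only the field names come from this diff:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class FinalWeightsQCfg:  # simplified stand-in for the node's final weights quantization config
        bias_corrected: Optional[Any] = None
        weights_second_moment_correction: bool = False

    def should_apply_bias_correction(cfg: Optional[FinalWeightsQCfg]) -> bool:
        # Correct only nodes that hold a precomputed correction term and are not
        # already handled by second-moment correction.
        return (cfg is not None and cfg.bias_corrected is not None
                and not cfg.weights_second_moment_correction)

    assert should_apply_bias_correction(FinalWeightsQCfg(bias_corrected=[0.1]))
    assert not should_apply_bias_correction(None)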
--- a/model_compression_toolkit/core/quantization_prep_runner.py
+++ b/model_compression_toolkit/core/quantization_prep_runner.py
@@ -52,7 +52,8 @@ def _collect_and_assign_act_threshold(graph: Graph,

     for n in graph.nodes:
         if n.is_activation_quantization_enabled():
-            activation_params = compute_activation_qparams(activation_quant_cfg=n.final_activation_quantization_cfg,
+            activation_params = compute_activation_qparams(quant_cfg=core_config.quantization_config,
+                                                           node_activation_quant_cfg=n.final_activation_quantization_cfg,
                                                            node_prior_info=n.prior_info,
                                                            out_stats_container=graph.get_out_stats_collector(n))
             n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
--- a/model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
+++ b/model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
@@ -43,17 +43,9 @@ def compute_bias_correction_of_graph(graph: Graph,
     for n in graph.nodes:
         # Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
         # name out of all the weights attributes of the node.
-        if n.kernel_attr:
-            if n.is_weights_quantization_enabled(n.kernel_attr):
-                # Bias correction is not applied to layers with constant inputs.
-                if n.has_positional_weights:
-                    for candidate_qc in n.candidates_quantization_cfg:
-                        candidate_qc.weights_quantization_cfg.weights_bias_correction = False
-                else:
-                    _compute_bias_correction_per_candidate_qc(n,
-                                                              n.kernel_attr,
-                                                              graph.get_in_stats_collector(n),
-                                                              fw_impl=fw_impl)
+        if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr) and not n.has_positional_weights:
+            _compute_bias_correction_per_candidate_qc(n, n.kernel_attr, graph.get_in_stats_collector(n),
+                                                      fw_impl=fw_impl)
     return graph


@@ -74,8 +66,7 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode,
     """

     for candidate_qc in node.candidates_quantization_cfg:
-        if candidate_qc.weights_quantization_cfg.weights_bias_correction and not \
-                candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
+        if not candidate_qc.weights_quantization_cfg.weights_second_moment_correction:

             quantized_kernel, io_channels_axes = get_quantized_weights_attr_by_qc(kernel_attr,
                                                                                   node,
--- a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py
+++ b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py
@@ -56,8 +56,9 @@ def statistics_correction_runner(transformed_graph: Graph,
     ########################################################
     # Compute bias correction to nodes' config candidates
     ########################################################
-    tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
-                                                    fw_impl)
+    if core_config.quantization_config.weights_bias_correction:
+        tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
+                                                        fw_impl)

     if tb_w is not None:
         tb_w.add_graph(tg_with_bias, 'statistics_computation')
@@ -96,7 +97,6 @@ def apply_statistics_correction(transformed_graph: Graph,
     #############################################
     if core_config.quantization_config.weights_bias_correction:
         transformed_graph = apply_bias_correction_to_graph(transformed_graph,
-                                                           core_config,
                                                            fw_impl=fw_impl)
     if tb_w is not None:
         tb_w.add_graph(transformed_graph, 'after_statistics_correction')
--- a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
+++ b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
@@ -20,7 +20,6 @@ from typing import Callable
 import numpy as np

 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
@@ -84,14 +83,10 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
         # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
         # we should skip the substitution.
         if source_node.is_reused():
-            for qc in source_node.candidates_quantization_cfg:
-                qc.weights_quantization_cfg.weights_second_moment_correction = False
             return graph

         # We apply only on nodes with folded BatchNormalization.
         if source_node.prior_info.std_output is None or source_node.prior_info.mean_output is None:
-            for qc in source_node.candidates_quantization_cfg:
-                qc.weights_quantization_cfg.weights_second_moment_correction = False
             return graph

         # This feature disabled for models with weights quantization method of Power of 2
@@ -103,10 +98,13 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
                 == QuantizationMethod.POWER_OF_TWO):
             Logger.warning("Second moment statistics correction feature disabled for models with weights "
                            "quantization method of Power of 2")
-            for qc_inner in source_node.candidates_quantization_cfg:
-                qc_inner.weights_quantization_cfg.weights_second_moment_correction = False
             return graph

+        # turn on second moment correction flag
+        def set_second_moment_correction(qc):
+            qc.weights_quantization_cfg.weights_second_moment_correction = True
+        source_node.quantization_cfg.update_all(set_second_moment_correction)
+
         eps = self.epsilon_val

         original_gamma = source_node.prior_info.std_output
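
Note: the flag flips from opt-out (forced to False on every early-return path) to opt-in (set to True once, after all preconditions pass) via the candidates-wide `update_all` hook. A simplified, runnable sketch of the pattern — the config classes are minimal stand-ins, not MCT's:

    from dataclasses import dataclass, field
    from typing import Callable, List

    @dataclass
    class WeightsQCfg:  # stand-in for the weights quantization config
        weights_second_moment_correction: bool = False

    @dataclass
    class CandidateQCfg:  # stand-in for a candidate quantization config
        weights_quantization_cfg: WeightsQCfg = field(default_factory=WeightsQCfg)

    @dataclass
    class NodeQCfg:  # stand-in for node.quantization_cfg
        candidates_quantization_cfg: List[CandidateQCfg]

        def update_all(self, fn: Callable[[CandidateQCfg], None]) -> None:
            # Apply the same mutation to every candidate config.
            for qc in self.candidates_quantization_cfg:
                fn(qc)

    node_qcfg = NodeQCfg([CandidateQCfg(), CandidateQCfg()])
    node_qcfg.update_all(
        lambda qc: setattr(qc.weights_quantization_cfg, 'weights_second_moment_correction', True))
    assert all(qc.weights_quantization_cfg.weights_second_moment_correction
               for qc in node_qcfg.candidates_quantization_cfg)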
--- a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
+++ b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
@@ -298,7 +298,7 @@ def shift_negative_function(graph: Graph,

     negative_rate = np.abs(min_to_correct) / activation_threshold

-    enable_sub = negative_rate <= non_linear_node_cfg_candidate.shift_negative_ratio
+    enable_sub = negative_rate <= core_config.quantization_config.shift_negative_ratio
     if min_to_correct >= 0 or not enable_sub:
         return graph

@@ -316,7 +316,7 @@ def shift_negative_function(graph: Graph,
     if core_config.quantization_config.shift_negative_params_search:

         hist_bins, hist_count = graph.get_out_stats_collector(non_linear_node).hc.get_histogram()
-        hist_count = z_score_filter(non_linear_node_cfg_candidate.z_threshold,
+        hist_count = z_score_filter(core_config.quantization_config.z_threshold,
                                     hist_bins, hist_count)

         min_mse, _th, _shift = np.inf, None, None
@@ -471,10 +471,12 @@ def shift_negative_function(graph: Graph,
                                               pad_node=pad_node,
                                               op2d_node=op2d_node)

-    if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
-        activation_param = compute_activation_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
+    if core_config.quantization_config.shift_negative_threshold_recalculation:
+        activation_param = compute_activation_qparams(quant_cfg=core_config.quantization_config,
+                                                      node_activation_quant_cfg=non_linear_node_cfg_candidate,
                                                       node_prior_info=non_linear_node.prior_info,
-                                                      out_stats_container=graph.get_out_stats_collector(non_linear_node))
+                                                      out_stats_container=graph.get_out_stats_collector(
+                                                          non_linear_node))

         assert activation_param.get(SIGNED) is False
         for candidate_qc in non_linear_node.candidates_quantization_cfg:
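
Note: `shift_negative_ratio`, `z_threshold`, and `shift_negative_threshold_recalculation` are now read from `core_config.quantization_config` instead of a per-node candidate copy; the gating arithmetic itself is unchanged. A worked example of the ratio check with made-up numbers:

    import numpy as np

    # Illustrative values only; shift_negative_ratio would come from
    # core_config.quantization_config.
    min_to_correct, activation_threshold = -0.05, 1.0
    shift_negative_ratio = 0.25

    negative_rate = np.abs(min_to_correct) / activation_threshold  # 0.05
    enable_sub = negative_rate <= shift_negative_ratio             # True -> shift is applied
    assert enable_sub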
--- a/model_compression_toolkit/core/graph_prep_runner.py
+++ b/model_compression_toolkit/core/graph_prep_runner.py
@@ -153,20 +153,10 @@ def get_finalized_graph(initial_graph: Graph,
     if bit_width_config:
         set_manual_bitwidth_config(graph, bit_width_config)

-    # TODO irena: load_fqc_configuration only loads config from tpc. Previously quant_config was read as well.
-    #  As a first stage we keep the attributes in internal configs and fill them manually from quant_config
-    #  not to break all the code at once. Eventually we need to handle quant_config directly, without injecting into candidates.
-    # TODO 2: Also we adjust candidates for single precision, which we shouldn't do here.
-    def update(qc):
-        qc.activation_quantization_cfg.set_qc(quant_config)
-        qc.weights_quantization_cfg.set_qc(quant_config)
-        for attr_cfg in qc.weights_quantization_cfg.get_all_weight_attrs_configs().values():
-            attr_cfg.weights_error_method = quant_config.weights_error_method
-            attr_cfg.l_p_value = quant_config.l_p_value
+    # TODO irena: remove after base config is used
     for n in transformed_graph.nodes:
         if not mixed_precision_enable:
             n.quantization_cfg.candidates_quantization_cfg = [n.quantization_cfg.base_quantization_cfg]
-        n.quantization_cfg.update_all(update)

     ######################################
     # Channel equalization
--- a/model_compression_toolkit/core/keras/default_framework_info.py
+++ b/model_compression_toolkit/core/keras/default_framework_info.py
@@ -143,7 +143,7 @@ class KerasInfo(FrameworkInfo):
             Node's output channel axis.

         """
-        return cls.out_channel_axis_mapping.get(node_type, -1)
+        return cls.out_channel_axis_mapping.get(node_type)


 def set_keras_info(func):
--- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py
+++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py
@@ -17,7 +17,7 @@
 from tensorflow.keras.layers import InputLayer, Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, ZeroPadding2D
 from typing import List

-from model_compression_toolkit.core import common
+from model_compression_toolkit.core import common, QuantizationConfig
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
@@ -47,7 +47,8 @@ class BaseInputScaling(common.BaseSubstitution):
     """

     def __init__(self,
-                 matcher_instance):
+                 matcher_instance,
+                 quant_cfg: QuantizationConfig):
         """
         Matches: InputLayer -> (optional nodes) -> (Dense,Conv2D,DepthwiseConv2D,Conv2DTranspose)
         note: the optional nodes are nodes that don't affect the scaling (such as ZeroPadding)
@@ -55,10 +56,11 @@ class BaseInputScaling(common.BaseSubstitution):
         Create a substitution using different params which may affect the way this substitution is made.
         The substitution is looking for edges in the graph which are input layers connected to linear layers.
         Args:
-            matcher_instance: matcher instance of type WalkMatcher
-
+            matcher_instance: matcher instance of type WalkMatcher.
+            quant_cfg: quantization config.
         """
         super().__init__(matcher_instance=matcher_instance)
+        self.quant_cfg = quant_cfg

     def substitute(self,
                    graph: Graph,
@@ -105,9 +107,11 @@ class BaseInputScaling(common.BaseSubstitution):
         for nqc in linear_layer.candidates_quantization_cfg:
             attr_cfg = nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr)
             assert attr_cfg.enable_weights_quantization
-            w_params, _ = compute_weights_qparams(w1_fixed, attr_quant_config=attr_cfg,
-                                                  output_channels_axis=attr_cfg.weights_channels_axis.output,
-                                                  min_threshold=nqc.weights_quantization_cfg.min_threshold)
+            w_params, _ = compute_weights_qparams(w1_fixed,
+                                                  attr_quant_config=attr_cfg,
+                                                  weights_error_method=self.quant_cfg.weights_error_method,
+                                                  l_p_value=self.quant_cfg.l_p_value,
+                                                  output_channels_axis=attr_cfg.weights_channels_axis.output)
             attr_cfg.set_weights_quantization_param(w_params)

         return graph
@@ -118,12 +122,15 @@ class InputScaling(BaseInputScaling):
     Substitution extends BaseInputScaling to the case of Input-->Linear
     """

-    def __init__(self):
+    def __init__(self, quant_cfg: QuantizationConfig):
         """
         Initialize a ScaleEqualization object.
+
+        Args:
+            quant_cfg: quantization config.
         """

-        super().__init__(matcher_instance=INPUT_MATCHER)
+        super().__init__(matcher_instance=INPUT_MATCHER, quant_cfg=quant_cfg)


 class InputScalingWithPad(BaseInputScaling):
@@ -131,9 +138,12 @@ class InputScalingWithPad(BaseInputScaling):
     Substitution extends BaseInputScaling to the case of Input-->ZeroPadding-->Linear
     """

-    def __init__(self):
+    def __init__(self, quant_cfg: QuantizationConfig):
         """
         Initialize a ScaleEqualization object.
+
+        Args:
+            quant_cfg: quantization config.
         """

-        super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD)
+        super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD, quant_cfg=quant_cfg)
--- a/model_compression_toolkit/core/keras/keras_implementation.py
+++ b/model_compression_toolkit/core/keras/keras_implementation.py
@@ -357,8 +357,8 @@ class KerasImplementation(FrameworkImplementation):
         if quant_config.softmax_shift:
             substitutions_list.append(keras_softmax_shift())
         if quant_config.input_scaling:
-            substitutions_list.append(InputScaling())
-            substitutions_list.append(InputScalingWithPad())
+            substitutions_list.append(InputScaling(quant_config))
+            substitutions_list.append(InputScalingWithPad(quant_config))
         if quant_config.concat_threshold_update:
             substitutions_list.append(ConcatThresholdUpdate())
         return substitutions_list
--- a/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
+++ b/model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
@@ -28,6 +28,10 @@ import numpy as np
 from model_compression_toolkit.logger import Logger


+# default output channel axis to use when it's not defined in node's fw_info.
+_default_output_channel_axis = -1
+
+
 class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
     """
     Implementation of the PruningFramework for the Keras framework. This class provides
@@ -172,6 +176,10 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa

         return attributes_with_axis

+    @property
+    def default_output_channel_axis(self):
+        return _default_output_channel_axis
+

 def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
     """
--- a/model_compression_toolkit/core/pytorch/default_framework_info.py
+++ b/model_compression_toolkit/core/pytorch/default_framework_info.py
@@ -101,7 +101,7 @@ class PyTorchInfo(FrameworkInfo):
             Node's output channel axis.

         """
-        return cls.out_channel_axis_mapping.get(node_type, 1)
+        return cls.out_channel_axis_mapping.get(node_type)


 def set_pytorch_info(func):