mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250702.605__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/METADATA +16 -16
  2. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/RECORD +75 -72
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
  5. model_compression_toolkit/core/common/framework_info.py +5 -32
  6. model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
  7. model_compression_toolkit/core/common/graph/base_graph.py +20 -37
  8. model_compression_toolkit/core/common/graph/base_node.py +13 -106
  9. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  10. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
  11. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
  12. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
  13. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
  14. model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
  15. model_compression_toolkit/core/common/network_editors/actions.py +4 -96
  16. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  17. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
  18. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  19. model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
  20. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
  22. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
  23. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
  24. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
  25. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  26. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
  27. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
  28. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
  29. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
  30. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
  31. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
  32. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
  33. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
  34. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
  35. model_compression_toolkit/core/graph_prep_runner.py +31 -20
  36. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
  37. model_compression_toolkit/core/keras/default_framework_info.py +0 -11
  38. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
  39. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
  40. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
  41. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
  42. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  43. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
  44. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
  45. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
  46. model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
  47. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  48. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
  49. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
  50. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
  51. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
  52. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  53. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
  54. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
  55. model_compression_toolkit/core/runner.py +1 -1
  56. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
  57. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  58. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
  62. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  63. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  64. model_compression_toolkit/quantization_preparation/__init__.py +14 -0
  65. model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
  66. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  67. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
  68. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/WHEEL +0 -0
  69. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/licenses/LICENSE.md +0 -0
  70. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/top_level.txt +0 -0
  71. /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
  72. /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
  73. /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
  74. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
  75. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
  76. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
@@ -18,7 +18,7 @@ from typing import Any, Callable
  from model_compression_toolkit.core import QuantizationConfig
  from model_compression_toolkit.core.common import BaseNode, Graph
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn


  def get_previous_node_with_activation_quantization(linear_node: BaseNode,
@@ -67,7 +67,8 @@ def compute_activation_bias_correction(graph: Graph,
  fw_impl: FrameworkImplementation,
  linear_node: BaseNode,
  prev_node: BaseNode,
- kernel_size: str) -> Graph:
+ kernel_size: str,
+ get_activation_quantization_fn_factory: Callable) -> Graph:
  """
  Compute the activation bias correction term, and store it in the final activation
  quantization configuration.
@@ -79,6 +80,7 @@ def compute_activation_bias_correction(graph: Graph,
  linear_node: Node to compute the activation bias correction for.
  prev_node: Node to compute the activation error caused by his activation quantization.
  kernel_size: The framework specific attribute name of the convolution layer's kernel size.
+ get_activation_quantization_fn_factory: activation quantization functions factory.

  Returns:
  Graph with activation bias correction term for each node.
@@ -105,7 +107,9 @@ def compute_activation_bias_correction(graph: Graph,
  float_centers = calculate_bin_centers(float_bins)

  # Quantize the bin edges and calculate the centers of the quantized bins
- quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
+ activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg,
+ get_activation_quantization_fn_factory)
+ quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
  quant_bins = fw_impl.to_numpy(quant_bins)
  quant_centers = calculate_bin_centers(quant_bins)

@@ -149,7 +153,8 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
  quant_config: QuantizationConfig,
  fw_impl: FrameworkImplementation,
  activation_bias_correction_node_matchers: Callable,
- kernel_size: str) -> Graph:
+ kernel_size: str,
+ get_activation_quantization_fn_factory: Callable) -> Graph:
  """
  Compute the activation bias correction term for the graph.

@@ -159,7 +164,7 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
  activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
  kernel_size: The framework specific attribute name of the convolution layer's kernel size.
-
+ get_activation_quantization_fn_factory: activation quantization functions factory.

  Returns:
  Graph with activation bias correction term for each relevant node.
@@ -175,5 +180,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
  fw_impl=fw_impl,
  linear_node=n,
  prev_node=prev_node,
- kernel_size=kernel_size)
+ kernel_size=kernel_size,
+ get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
  return graph
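
Note (illustrative, not part of the diff): these hunks replace the old FrameworkInfo-based quantizer lookup with an explicit factory argument that each framework passes in. A minimal sketch of the new call chain, assuming a Keras flow where act_quant_cfg and fw_impl are in scope as in the hunk above:

    from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
        get_activation_quantization_fn
    from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import \
        get_activation_quantization_fn_factory

    # Resolve a quantizer for the node's activation config, then apply it to a framework tensor.
    activation_quantizer = get_activation_quantization_fn(act_quant_cfg,
                                                          get_activation_quantization_fn_factory)
    quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))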
@@ -43,7 +43,7 @@ def compute_bias_correction_of_graph(graph: Graph,
  for n in graph.nodes:
  # Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
  # name out of all the weights attributes of the node.
- if n.is_kernel_op:
+ if n.kernel_attr:
  if n.is_weights_quantization_enabled(n.kernel_attr):
  # Bias correction is not applied to layers with constant inputs.
  if n.has_positional_weights:
@@ -124,7 +124,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):

  bn_node.prior_info = copy.deepcopy(source_node.prior_info)

- bn_node.candidates_quantization_cfg = copy.deepcopy(source_node.candidates_quantization_cfg)
+ bn_node.quantization_cfg = copy.deepcopy(source_node.quantization_cfg)

  for qc in bn_node.candidates_quantization_cfg:
  qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -139,7 +139,6 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
  # reconstructed node BN attributes need to be quantized and how.
  qc.weights_quantization_cfg.set_attr_config(attr,
  WeightsAttrQuantizationConfig(
- QuantizationConfig(),
  AttributeQuantizationConfig(
  enable_weights_quantization=False)))

@@ -16,21 +16,20 @@ import copy
  import numpy as np
  from typing import List, Tuple, Any, Callable

- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
  from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
  ActivationQuantizationMode
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
+ from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
- from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
- set_quantization_configs_to_node
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
- import get_activations_qparams
+ import compute_activation_qparams
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
  _mse_error_histogram
  from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
+ from model_compression_toolkit.quantization_preparation.load_fqc import set_quantization_configs_to_node, \
+ fetch_qc_options_for_node
  from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig

  """
@@ -67,8 +66,7 @@ def op2d_bias_correction(op2d_node: BaseNode,
  # Add an attribute quantization configuration to the newly added bias attribute, with disabled quantization
  for qc in op2d_node.candidates_quantization_cfg:
  qc.weights_quantization_cfg.set_attr_config(bias_flag_str,
- WeightsAttrQuantizationConfig(QuantizationConfig(),
- AttributeQuantizationConfig(
+ WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
  enable_weights_quantization=False)))

  # Each node adds a different noise due to the shifting. It depends on the
@@ -253,6 +251,7 @@ def shift_negative_function(graph: Graph,
  padding_str: str,
  bias_str: str,
  bias_flag_str: str,
+ get_activation_quantization_fn_factory: Callable,
  zero_padding_node: BaseNode = None,
  bypass_nodes: List = None,
  params_search_quantization_fn: Callable = None
@@ -278,6 +277,7 @@ def shift_negative_function(graph: Graph,
  padding_str: The framework specific attribute name of the padding.
  bias_str: The framework specific attribute name of the bias.
  bias_flag_str: The framework specific attribute name of the bias flag.
+ get_activation_quantization_fn_factory: activation quantization functions factory.
  zero_padding_node: ZeroPadding2D node that may be in the graph before the linear layer.
  params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.

@@ -327,13 +327,15 @@ def shift_negative_function(graph: Graph,
  'float32') # Change to type float32 to support tensorflow dtypes
  for _shift_value in _q_points:
  _hist_bins = hist_bins.astype(np.float32) + _shift_value
- fw_quant_fn = non_linear_node_cfg_candidate.activation_quantization_fn(non_linear_node_cfg_candidate.activation_n_bits,qparams)
+ quantizer_factory = get_activation_quantization_fn_factory(
+ non_linear_node_cfg_candidate.activation_quantization_method)
+ fw_quant_fn = quantizer_factory(non_linear_node_cfg_candidate.activation_n_bits, qparams)
  """
  In SNC, when better shifting values are tested for better choice,
  the histogram (which is a numpy object) is quantized using the non-linear node activation
  quantization function (to estimate the expected mse comparing to the original histogram).
  The quantization function is a framework function, which makes it fail since it
- expects a fw tensor. The commmon part of SNC receives an argument which is a callable
+ expects a fw tensor. The common part of SNC receives an argument which is a callable
  that receives two argument and returns one: it gets the fw activation quantization function
  and the bins to quantize. The function (of each fw) responsible for doing (if needed) a preprocessing and postprocessing
  to the bins which is a numpy object.
@@ -395,9 +397,7 @@ def shift_negative_function(graph: Graph,

  set_quantization_configs_to_node(node=add_node,
  graph=graph,
- quant_config=core_config.quantization_config,
- fqc=graph.fqc,
- mixed_precision_enable=core_config.is_mixed_precision_enabled)
+ fqc=graph.fqc)

  update_fused_op_with_add(graph=graph,
  non_linear_node=non_linear_node,
@@ -421,9 +421,7 @@ def shift_negative_function(graph: Graph,
  # Set quantization configuration to node, even though we do not quantize it:
  set_quantization_configs_to_node(node=pad_node,
  graph=graph,
- quant_config=core_config.quantization_config,
- fqc=graph.fqc,
- mixed_precision_enable=core_config.is_mixed_precision_enabled)
+ fqc=graph.fqc)

  for candidate_qc in pad_node.candidates_quantization_cfg:
  candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -448,7 +446,7 @@ def shift_negative_function(graph: Graph,
  bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
  graph.shift_stats_collector(bypass_node, np.array(shift_value))

- add_node_qco = add_node.get_qco(graph.fqc).quantization_configurations
+ add_node_qco = fetch_qc_options_for_node(add_node, graph.fqc).quantization_configurations
  add_supported_bitwidths = [c.activation_n_bits for c in add_node_qco]
  if original_non_linear_activation_nbits not in add_supported_bitwidths:
  raise ValueError(
@@ -456,18 +454,16 @@ def shift_negative_function(graph: Graph,
  f"bitwidth is {original_non_linear_activation_nbits}. Consider adapting the TPC so 'Add' will support the "
  f"same bitwidth as {non_linear_node.type} or disable shift negative correction.")

- for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
- for attr in add_node.get_node_weights_attributes():
- # TODO: do we not quantize the weights of this 'add' on purpose?
- candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
+ set_quantization_configs_to_node(add_node, graph, graph.fqc)
+ # TODO: do we not quantize the weights of this 'add' on purpose?
+ add_node.quantization_cfg.disable_weights_quantization()

- candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config,
- add_node_qco[op_qc_idx])
+ def update(c):
+ c.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
+ c.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
+ SIGNED: False})

- candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
- SIGNED: False})
-
- candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
+ add_node.quantization_cfg.update_all(update, remove_duplicates=True)

  # Add the new padding node to a fused op with the op2d.
  if pad_node:
@@ -476,11 +472,11 @@ def shift_negative_function(graph: Graph,
  op2d_node=op2d_node)

  if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
- activation_param = get_activations_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
- nodes_prior_info=non_linear_node.prior_info,
- out_stats_container=graph.get_out_stats_collector(non_linear_node))
+ activation_param = compute_activation_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
+ node_prior_info=non_linear_node.prior_info,
+ out_stats_container=graph.get_out_stats_collector(non_linear_node))

- assert activation_param.get(SIGNED) == False
+ assert activation_param.get(SIGNED) is False
  for candidate_qc in non_linear_node.candidates_quantization_cfg:
  candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param)

@@ -573,6 +569,7 @@ def apply_shift_negative_correction(graph: Graph,
  padding_str: str,
  bias_str: str,
  bias_flag_str: str,
+ get_activation_quantization_fn_factory: Callable,
  params_search_quantization_fn: Callable=None) -> Graph:
  """
  Apply the substitution even if the linear node is not immediately after
@@ -594,6 +591,9 @@ def apply_shift_negative_correction(graph: Graph,
  padding_str: The framework specific attribute name of the padding.
  bias_str: The framework specific attribute name of the bias.
  bias_flag_str: The framework specific attribute name of the bias flag.
+ get_activation_quantization_fn_factory: activation quantization functions factory.
+ params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.
+
  Returns:
  Graph after applying shift negative on selected activations.
  """
@@ -601,9 +601,8 @@ def apply_shift_negative_correction(graph: Graph,
  nodes = list(graph.nodes())
  for n in nodes:
  # Skip substitution if QuantizationMethod is uniform.
- node_qco = n.get_qco(graph.fqc)
- if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
- for op_qc in node_qco.quantization_configurations]):
+ if any(aqc.activation_quantization_cfg.activation_quantization_method == QuantizationMethod.UNIFORM
+ for aqc in n.candidates_quantization_cfg):
  continue

  if snc_node_types.apply(n):
@@ -625,6 +624,7 @@ def apply_shift_negative_correction(graph: Graph,
  padding_str,
  bias_str,
  bias_flag_str,
+ get_activation_quantization_fn_factory,
  zero_padding_node=pad_node,
  bypass_nodes=bypass_nodes,
  params_search_quantization_fn=params_search_quantization_fn)
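
Note (illustrative, not part of the diff): the SNC docstring above describes the contract of params_search_quantization_fn as a callable that receives the framework activation quantization function and the numpy histogram bins, and returns quantized numpy bins. A hypothetical torch-side implementation of that contract:

    import numpy as np
    import torch

    def torch_params_search_quantization_fn(fw_quant_fn, bins: np.ndarray) -> np.ndarray:
        t = torch.from_numpy(bins)        # preprocessing: numpy -> framework tensor
        q = fw_quant_fn(t)                # framework activation quantization
        return q.detach().cpu().numpy()   # postprocessing: framework tensor -> numpy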
@@ -207,7 +207,7 @@ class TensorboardWriter(object):
  attr = dict()
  if n.final_activation_quantization_cfg is not None:
  attr.update(n.final_activation_quantization_cfg.__dict__)
- elif n.candidates_quantization_cfg is not None:
+ elif n.quantization_cfg is not None:
  attr.update(n.get_unified_activation_candidates_dict())
  return attr

@@ -229,7 +229,7 @@ class TensorboardWriter(object):
  attr = dict()
  if n.final_weights_quantization_cfg is not None:
  attr.update(n.final_weights_quantization_cfg.__dict__)
- elif n.candidates_quantization_cfg is not None:
+ elif n.quantization_cfg is not None:
  attr.update(n.get_unified_weights_candidates_dict())
  return attr

@@ -530,8 +530,6 @@ def init_tensorboard_writer() -> TensorboardWriter:
  Create a TensorBoardWriter object initialized with the logger dir path if it was set,
  or None otherwise.

- Args:
-
  Returns:
  A TensorBoardWriter object.
  """
@@ -16,22 +16,22 @@

  from typing import Callable, Any

- from model_compression_toolkit.core.common import FrameworkInfo
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
  from model_compression_toolkit.core.common.graph.base_graph import Graph
  from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
- from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
+ from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG, \
+ QuantizationErrorMethod
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
- from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \
- set_quantization_configuration_to_graph
+ from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_manual_bitwidth_config
  from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
  from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
  linear_collapsing_substitute
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
+ from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
  FrameworkQuantizationCapabilities
+ from model_compression_toolkit.logger import Logger


  def graph_preparation_runner(in_model: Any,
@@ -112,6 +112,12 @@ def get_finalized_graph(initial_graph: Graph,

  Returns: Graph object that represents the model, after applying all required modifications to it.
  """
+ if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
+ if not running_gptq:
+ raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
+ f"optimization due to long execution time that is not suitable for basic PTQ.")
+ Logger.warning("Using the HMSE error method for weights quantization parameters search. "
+ "Note: This method may significantly increase runtime during the parameter search process.")

  ######################################
  # Graph substitution (prepare graph)
@@ -141,21 +147,26 @@ def get_finalized_graph(initial_graph: Graph,
  if tb_w is not None:
  tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')

- ######################################
- # Add quantization configurations
- ######################################
- transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
- quant_config=quant_config,
- bit_width_config=bit_width_config,
- mixed_precision_enable=mixed_precision_enable,
- running_gptq=running_gptq)
-
- ######################################
- # Layer fusing
- ######################################
- fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(transformed_graph)
- transformed_graph.fusing_info = fusing_info
- transformed_graph.override_fused_node_activation_quantization_candidates()
+ transformed_graph = load_fqc_configuration(transformed_graph, fqc)
+
+ # filter candidates per manual config
+ if bit_width_config:
+ set_manual_bitwidth_config(graph, bit_width_config)
+
+ # TODO irena: load_fqc_configuration only loads config from tpc. Previously quant_config was read as well.
+ # As a first stage we keep the attributes in internal configs and fill them manually from quant_config
+ # not to break all the code at once. Eventually we need to handle quant_config directly, without injecting into candidates.
+ # TODO 2: Also we adjust candidates for single precision, which we shouldn't do here.
+ def update(qc):
+ qc.activation_quantization_cfg.set_qc(quant_config)
+ qc.weights_quantization_cfg.set_qc(quant_config)
+ for attr_cfg in qc.weights_quantization_cfg.get_all_weight_attrs_configs().values():
+ attr_cfg.weights_error_method = quant_config.weights_error_method
+ attr_cfg.l_p_value = quant_config.l_p_value
+ for n in transformed_graph.nodes:
+ if not mixed_precision_enable:
+ n.quantization_cfg.candidates_quantization_cfg = [n.quantization_cfg.base_quantization_cfg]
+ n.quantization_cfg.update_all(update)

  ######################################
  # Channel equalization
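
Note (illustrative, not part of the diff): this hunk replaces the set_quantization_configuration_to_graph + FusingInfoGenerator sequence with load_fqc_configuration, which loads TPC candidates and fusing info, while user QuantizationConfig fields are pushed into each node's candidates through the new quantization_cfg.update_all helper. A minimal sketch of that pattern, assuming a node n whose candidates were already loaded:

    # Hypothetical illustration of the update_all pattern used above.
    def update(qc):
        # qc is one candidate quantization config; copy user settings into it.
        qc.activation_quantization_cfg.set_qc(quant_config)
        qc.weights_quantization_cfg.set_qc(quant_config)

    n.quantization_cfg.update_all(update)  # applies the callback across the node's candidate configs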
@@ -14,9 +14,10 @@
  # ==============================================================================
  from typing import List

- from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import BaseNode
+ from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
+ from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
  from tensorflow.python.util.object_identity import Reference as TFReference

@@ -56,4 +57,6 @@ class QuantizedKerasModelBuilder(KerasModelBuilder):
  Output of the node.

  """
- return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
+ activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
+ get_activation_quantization_fn_factory)
+ return activation_quantizer(input_tensors)
@@ -18,7 +18,6 @@ import tensorflow as tf
  from typing import Tuple, Any, Dict
  from functools import wraps

- from model_compression_toolkit.core.keras.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
  from packaging import version

  if version.parse(tf.__version__) >= version.parse("2.13"):
@@ -26,11 +25,9 @@ if version.parse(tf.__version__) >= version.parse("2.13"):
  else:
  from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation # pragma: no cover
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
- from mct_quantizers import QuantizationMethod
  from model_compression_toolkit.constants import SOFTMAX_THRESHOLD, ACTIVATION
  from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
  KERNEL, DEPTHWISE_KERNEL, GELU
- from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization


  class KerasInfo(FrameworkInfo):
@@ -103,14 +100,6 @@ class KerasInfo(FrameworkInfo):
  tf.nn.softmax: (0, SOFTMAX_THRESHOLD),
  }

- """
- Mapping from a QuantizationMethod to an activation quantizer function.
- """
- activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
- QuantizationMethod.SYMMETRIC: symmetric_quantization,
- QuantizationMethod.UNIFORM: uniform_quantization,
- QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
-
  @classmethod
  def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
  """
@@ -18,13 +18,12 @@ from tensorflow.keras.layers import InputLayer, Dense, DepthwiseConv2D, Conv2D,
  from typing import List

  from model_compression_toolkit.core import common
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from model_compression_toolkit.core.common.graph.base_graph import Graph
- from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, WalkMatcher
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
  from model_compression_toolkit.constants import THRESHOLD
- from model_compression_toolkit.core.keras.constants import KERNEL
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
+ compute_weights_qparams
  from model_compression_toolkit.logger import Logger

  input_node = NodeOperationMatcher(InputLayer)
@@ -104,8 +103,12 @@ class BaseInputScaling(common.BaseSubstitution):

  # After scaling weights may have different thresholds so it needs to be recalculated
  for nqc in linear_layer.candidates_quantization_cfg:
- nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr).calculate_and_set_weights_params(w1_fixed,
- nqc.weights_quantization_cfg.min_threshold)
+ attr_cfg = nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr)
+ assert attr_cfg.enable_weights_quantization
+ w_params, _ = compute_weights_qparams(w1_fixed, attr_quant_config=attr_cfg,
+ output_channels_axis=attr_cfg.weights_channels_axis.output,
+ min_threshold=nqc.weights_quantization_cfg.min_threshold)
+ attr_cfg.set_weights_quantization_param(w_params)

  return graph

@@ -34,6 +34,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
  NodeFrameworkAttrMatcher
  from model_compression_toolkit.core.common.substitutions.shift_negative_activation import \
  apply_shift_negative_correction
+ from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.core.keras.constants import KERNEL_SIZE, STRIDES, ACTIVATION, SWISH, \
  SELU, GELU, FUNCTION, ADD, PAD
  from model_compression_toolkit.core.keras.constants import NEGATIVE_SLOPE, PADDING, PAD_SAME, PAD_VALID, BIAS, USE_BIAS
@@ -252,5 +253,6 @@ def keras_apply_shift_negative_correction(graph: Graph,
  is_padding_node_and_node_has_padding,
  PADDING,
  BIAS,
- USE_BIAS
+ USE_BIAS,
+ get_activation_quantization_fn_factory
  )
@@ -94,7 +94,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
  for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor

  # Check if the target node's layer type is supported.
- if not ipt_node.is_kernel_op:
+ if not ipt_node.kernel_attr:
  Logger.critical(f"Hessian information with respect to weights is not supported for "
  f"{ipt_node.type} layers.") # pragma: no cover

@@ -23,6 +23,7 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
  verify_candidates_descending_order, init_activation_quantizers
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
  CandidateNodeQuantizationConfig
+ from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.logger import Logger

  import tensorflow as tf
@@ -67,7 +68,7 @@ class ConfigurableActivationQuantizer(BaseKerasInferableQuantizer):
  if qc.activation_quantization_cfg.quant_mode != node_q_cfg[0].activation_quantization_cfg.quant_mode:
  Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).") # pragma: no cover

- self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
+ self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
  self.active_quantization_config_index = max_candidate_idx # initialize with first config as default

  def set_active_activation_quantizer(self, index: Optional[int]):
@@ -155,7 +155,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
  """

  attributes_with_axis = {}
- if node.is_kernel_op:
+ if node.kernel_attr:
  attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)

  # Bias is a vector at the length of the number of output channels.
@@ -0,0 +1,47 @@
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from collections.abc import Callable
+
+ from mct_quantizers import QuantizationMethod
+
+
+ from model_compression_toolkit.core.keras.quantization.fake_quant_builder import power_of_two_quantization, \
+ symmetric_quantization, uniform_quantization
+ from model_compression_toolkit.core.keras.quantization.lut_fake_quant import activation_lut_kmean_quantizer
+
+
+ """
+ Mapping from a QuantizationMethod to an activation quantizer function.
+ """
+ _activation_quantizer_factory_mapping = {
+ QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
+ QuantizationMethod.SYMMETRIC: symmetric_quantization,
+ QuantizationMethod.UNIFORM: uniform_quantization,
+ QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer
+ }
+
+
+ def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
+ """
+ Get factory for activation quantizer.
+
+ Args:
+ quantization_method: quantization method for activation.
+
+ Returns:
+ Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
+ """
+ return _activation_quantizer_factory_mapping[quantization_method]
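
Note (illustrative, not part of the diff): this new module replaces the activation_quantizer_mapping that the diff removes from KerasInfo above. Based on its use in the shift-negative hunk (quantizer_factory(activation_n_bits, qparams)), a hedged usage sketch; the params-dict keys and values are assumptions:

    from mct_quantizers import QuantizationMethod
    from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import \
        get_activation_quantization_fn_factory

    # Pick the factory for a quantization method, then build a quantizer
    # from a bitwidth and a dict of quantization params.
    factory = get_activation_quantization_fn_factory(QuantizationMethod.POWER_OF_TWO)
    quantizer = factory(8, {'threshold': 2.0, 'is_signed': False})  # param keys are an assumption
    quantized = quantizer(activation_tensor)  # activation_tensor: a tf tensor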
@@ -25,7 +25,7 @@ else:

  from model_compression_toolkit.core import QuantizationConfig
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.core.common import Graph
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
  from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
@@ -60,5 +60,6 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
  fw_impl=fw_impl,
  activation_bias_correction_node_matchers=
  activation_bias_correction_node_matchers,
- kernel_size=KERNEL_SIZE)
+ kernel_size=KERNEL_SIZE,
+ get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
  return graph
@@ -17,9 +17,10 @@ from typing import List, Tuple

  import torch

- from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import BaseNode
+ from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
+ from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.core.common.user_info import UserInformation
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
  PytorchModel
@@ -60,7 +61,9 @@ class QuantizedPyTorchModel(PytorchModel):
  if node.is_activation_quantization_enabled():
  if isinstance(input_tensors, list):
  input_tensors = torch.cat(input_tensors, dim=0)
- return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
+ activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
+ get_activation_quantization_fn_factory)
+ return activation_quantizer(input_tensors)
  return input_tensors
