mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -24,6 +24,8 @@ from model_compression_toolkit.core.common.user_info import UserInformation
24
24
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
25
25
  PytorchModel
26
26
 
27
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
28
+
27
29
 
28
30
  class FloatPyTorchModel(PytorchModel):
29
31
  """
@@ -32,16 +34,19 @@ class FloatPyTorchModel(PytorchModel):
32
34
 
33
35
  def __init__(self,
34
36
  graph: common.Graph,
35
- append2output=None):
37
+ append2output=None,
38
+ fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO):
36
39
  """
37
40
 
38
41
  Args:
39
42
  graph: Graph to build its corresponding Pytorch model.
40
43
  append2output: List of nodes or OutTensor objects.
44
+ fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
41
45
  """
42
46
 
43
47
  super().__init__(graph,
44
- append2output)
48
+ append2output,
49
+ fw_info)
45
50
 
46
51
  def _quantize_node_activations(self,
47
52
  node: BaseNode,
@@ -66,17 +71,20 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
66
71
  def __init__(self,
67
72
  graph: common.Graph,
68
73
  append2output=None,
74
+ fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
69
75
  return_float_outputs: bool = False):
70
76
  """
71
77
 
72
78
  Args:
73
79
  graph: Graph to build the model from.
74
80
  append2output: Nodes to append to model's output.
81
+ fw_info: Information about the specific framework of the model that is built.
75
82
  return_float_outputs: Whether the model returns float tensors or not.
76
83
  """
77
84
 
78
85
  super().__init__(graph,
79
86
  append2output,
87
+ fw_info,
80
88
  return_float_outputs)
81
89
 
82
90
  def build_model(self) -> Tuple[PytorchModel, UserInformation]:
@@ -86,4 +94,5 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
86
94
 
87
95
  """
88
96
  return FloatPyTorchModel(self.graph,
89
- self.append2output), self.graph.user_info
97
+ self.append2output,
98
+ self.fw_info), self.graph.user_info
@@ -23,6 +23,7 @@ from model_compression_toolkit.core import FrameworkInfo, common
23
23
  from model_compression_toolkit.core.common import BaseNode
24
24
  from model_compression_toolkit.core.common.user_info import UserInformation
25
25
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
26
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
26
27
  from model_compression_toolkit.core.pytorch.mixed_precision.configurable_activation_quantizer import \
27
28
  ConfigurableActivationQuantizer
28
29
  from model_compression_toolkit.core.pytorch.mixed_precision.configurable_weights_quantizer import \
@@ -37,12 +38,14 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
37
38
  def __init__(self,
38
39
  graph: common.Graph,
39
40
  append2output=None,
41
+ fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
40
42
  return_float_outputs: bool = False):
41
43
  """
42
44
 
43
45
  Args:
44
46
  graph: Graph to build the model from.
45
47
  append2output: Nodes to append to model's output.
48
+ fw_info: Information about the specific framework of the model that is built.
46
49
  return_float_outputs: Whether the model returns float tensors or not.
47
50
  """
48
51
 
@@ -50,6 +53,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
50
53
 
51
54
  super().__init__(graph,
52
55
  append2output,
56
+ fw_info,
53
57
  return_float_outputs,
54
58
  wrapper=self.mixed_precision_wrapper,
55
59
  get_activation_quantizer_holder_fn=self.mixed_precision_activation_holder)
@@ -73,16 +77,17 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
73
77
  ValueError: if kernel attribute is quantized but not configurable.
74
78
  """
75
79
 
76
- if n.kernel_attr is None or not n.is_weights_quantization_enabled(n.kernel_attr):
80
+ kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
81
+ if kernel_attr is None or not n.is_weights_quantization_enabled(kernel_attr):
77
82
  return layer
78
- if not n.is_configurable_weight(n.kernel_attr): # pragma: no cover
83
+ if not n.is_configurable_weight(kernel_attr): # pragma: no cover
79
84
  raise ValueError(f'Weight wrapper is not expected to be created for non-configurable weight of node {n}.')
80
85
  return PytorchQuantizationWrapper(layer,
81
86
  weights_quantizers={
82
- n.kernel_attr: ConfigurableWeightsQuantizer(
87
+ kernel_attr: ConfigurableWeightsQuantizer(
83
88
  **self._get_weights_configurable_quantizer_kwargs(n,
84
- n.kernel_attr),
85
- kernel_attr=n.kernel_attr)})
89
+ kernel_attr),
90
+ kernel_attr=kernel_attr)})
86
91
 
87
92
  def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> Dict[str, Any]:
88
93
  """
@@ -142,13 +147,14 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
142
147
  # activation number of bits (in reversed order).
143
148
  # since only kernel attribute is quantized in weights mixed precision,
144
149
  # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
145
- n.sort_node_candidates()
150
+ n.sort_node_candidates(self.fw_info)
146
151
 
147
152
  max_candidate_idx = n.find_max_candidate_index()
148
153
 
154
+ kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
149
155
  activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
150
156
  'max_candidate_idx': max_candidate_idx,
151
- 'kernel_attr': n.kernel_attr})] \
157
+ 'kernel_attr': kernel_attr})] \
152
158
  * num_of_outputs
153
159
 
154
160
  # Holder by definition uses a single quantizer for the activation quantization
@@ -171,7 +177,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
171
177
  # creating a mapping between graph nodes and model's layers for mixed precision configurability
172
178
  model_layers = dict(model.named_children())
173
179
  conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model_layers)
174
- for n in self.graph.get_configurable_sorted_nodes()}
180
+ for n in self.graph.get_configurable_sorted_nodes(self.fw_info)}
175
181
 
176
182
  return model, user_info, conf_node2layers
177
183
 
@@ -224,7 +230,8 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
224
230
 
225
231
  """
226
232
  # Only layers with kernel op are considered weights configurable
227
- weights_quant = False if n.kernel_attr is None else n.is_weights_quantization_enabled(n.kernel_attr)
233
+ kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
234
+ weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr)
228
235
  act_quant = n.is_activation_quantization_enabled()
229
236
 
230
237
  if weights_quant and not act_quant:
@@ -30,6 +30,7 @@ from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
30
30
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
31
31
  from model_compression_toolkit.core.common.user_info import UserInformation
32
32
  from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
33
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
33
34
  from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
34
35
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
35
36
  from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
@@ -363,7 +364,7 @@ class PytorchModel(torch.nn.Module):
363
364
  """
364
365
  node_to_output_tensors_dict = dict()
365
366
  node_to_output_tensors_dict_float = dict()
366
- configurable_nodes = self.graph.get_configurable_sorted_nodes_names()
367
+ configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO)
367
368
  for node in self.node_sort:
368
369
  op_func = self._get_op_func(node, configurable_nodes)
369
370
  input_tensors = _build_input_tensors_list(node,
@@ -439,6 +440,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
439
440
  def __init__(self,
440
441
  graph: common.Graph,
441
442
  append2output=None,
443
+ fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
442
444
  return_float_outputs: bool = False,
443
445
  wrapper: Callable = None,
444
446
  get_activation_quantizer_holder_fn: Callable = None):
@@ -447,6 +449,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
447
449
  Args:
448
450
  graph: Graph to build the model from.
449
451
  append2output: Nodes to append to model's output.
452
+ fw_info: Information about the specific framework of the model that is built.
450
453
  return_float_outputs: Whether the model returns float tensors or not.
451
454
  wrapper: A function wrapper Pytorch Layers.
452
455
  get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
@@ -454,6 +457,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
454
457
 
455
458
  super().__init__(graph,
456
459
  append2output,
460
+ fw_info,
457
461
  return_float_outputs)
458
462
 
459
463
  self.wrapper = wrapper
@@ -21,6 +21,7 @@ from model_compression_toolkit.core.common import BaseNode
21
21
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
22
22
  from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
23
23
  WrapperQuantizeConfig
24
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
24
25
  from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
25
26
 
26
27
 
@@ -92,7 +93,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
92
93
  self.layer = n.type(**framework_attr)
93
94
  self.layer.load_state_dict({k: torch.Tensor(v) for k, v in n.weights.items()}, strict=False)
94
95
 
95
- def _quantize_weights(self, n: BaseNode):
96
+ def _quantize_weights(self, n:BaseNode):
96
97
  """
97
98
  Quantize node's weights and load them as the layer's weights.
98
99
 
@@ -103,7 +104,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
103
104
  None.
104
105
  """
105
106
 
106
- self.weight_attrs = [n.kernel_attr]
107
+ self.weight_attrs = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(n.type)
107
108
 
108
109
  # float_weights is a list of weights for each attribute that we want to quantize.
109
110
  float_weights = [n.get_weights_by_keys(attr) for attr in self.weight_attrs]
@@ -17,13 +17,13 @@ from typing import List, Tuple
17
17
 
18
18
  import torch
19
19
 
20
+ from model_compression_toolkit.core import FrameworkInfo
20
21
  from model_compression_toolkit.core import common
21
22
  from model_compression_toolkit.core.common import BaseNode
22
- from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
23
- from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
24
23
  from model_compression_toolkit.core.common.user_info import UserInformation
25
24
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
26
25
  PytorchModel
26
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
27
27
 
28
28
 
29
29
  class QuantizedPyTorchModel(PytorchModel):
@@ -61,9 +61,7 @@ class QuantizedPyTorchModel(PytorchModel):
61
61
  if node.is_activation_quantization_enabled():
62
62
  if isinstance(input_tensors, list):
63
63
  input_tensors = torch.cat(input_tensors, dim=0)
64
- activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
65
- get_activation_quantization_fn_factory)
66
- return activation_quantizer(input_tensors)
64
+ return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
67
65
  return input_tensors
68
66
 
69
67
 
@@ -72,17 +70,20 @@ class QuantizedPyTorchModelBuilder(PyTorchModelBuilder):
72
70
  def __init__(self,
73
71
  graph: common.Graph,
74
72
  append2output=None,
73
+ fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
75
74
  return_float_outputs: bool = False):
76
75
  """
77
76
 
78
77
  Args:
79
78
  graph: Graph to build the model from.
80
79
  append2output: Nodes to append to model's output.
80
+ fw_info: Information about the specific framework of the model that is built.
81
81
  return_float_outputs: Whether the model returns float tensors or not.
82
82
  """
83
83
 
84
84
  super().__init__(graph,
85
85
  append2output,
86
+ fw_info,
86
87
  return_float_outputs)
87
88
 
88
89
  def build_model(self) -> Tuple[PytorchModel, UserInformation]:
@@ -12,101 +12,87 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Any
16
- from functools import wraps
17
-
18
- from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU, SiLU
19
- from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu, silu
15
+ from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU
16
+ from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu
20
17
  from torch.nn import Conv2d, ConvTranspose2d, Linear
21
18
  from torch import sigmoid
22
19
 
23
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
20
+ from model_compression_toolkit.defaultdict import DefaultDict
21
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
22
+ from mct_quantizers import QuantizationMethod
24
23
  from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
25
24
  from model_compression_toolkit.core.pytorch.constants import KERNEL
26
-
27
-
28
- class PyTorchInfo(FrameworkInfo):
29
- """
30
- Extra field defined to handle Activation layer functions:
31
- """
32
-
33
- """
34
- Map each layer to it's weight attribute that should get quantized.
35
- If a layer that is not listed here is queried, None is returned.
36
- """
37
- kernel_ops_attribute_mapping = {Conv2d: KERNEL,
38
- ConvTranspose2d: KERNEL,
39
- Linear: KERNEL}
40
-
41
- """
42
- Map a layer to its kernel's output and input channels indices.
43
- Map's values are tuples of (output_channel_index, input_channel_index).
44
- Default value is returned for layers that are not included.
45
- """
46
- kernel_channels_mapping = {Conv2d: ChannelAxisMapping(0, 1),
47
- Linear: ChannelAxisMapping(0, 1),
48
- ConvTranspose2d: ChannelAxisMapping(1, 0)}
49
-
50
- """
51
- Map a layer to its output channel axis.
52
- Where axis=-1 is the last axis
53
- """
54
- out_channel_axis_mapping = {Conv2d: 1,
55
- Linear: -1,
56
- ConvTranspose2d: 1}
57
-
58
- """
59
- Map from an Pytorch module to its min/max output values (if known).
60
- The values are used for tensor min/max values initialization.
61
- """
62
- _layer_min_max_mapping = {Softmax: (0, SOFTMAX_THRESHOLD),
63
- softmax: (0, SOFTMAX_THRESHOLD),
64
- Sigmoid: (0, 1),
65
- sigmoid: (0, 1),
66
- Hardsigmoid: (0, 1),
67
- hardsigmoid: (0, 1),
68
- ReLU: (0, None),
69
- relu: (0, None),
70
- ReLU6: (0, None),
71
- relu6: (0, None),
72
- GELU: (-0.17, None),
73
- gelu: (-0.17, None),
74
- SELU: (-1.76, None),
75
- selu: (-1.76, None),
76
- silu: (-0.279, None),
77
- SiLU: (-0.279, None),
78
- }
79
-
80
- @classmethod
81
- def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
82
- """
83
- Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
84
- Args:
85
- node_type: A node type.
86
-
87
- Returns:
88
- Node's channels mapping.
89
-
90
- """
91
- return cls.kernel_channels_mapping.get(node_type, cls._default_channel_mapping)
92
-
93
- @classmethod
94
- def get_out_channel_axis(cls, node_type: Any):
95
- """
96
- Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
97
- Args:
98
- node_type: A node type.
99
-
100
- Returns:
101
- Node's output channel axis.
102
-
103
- """
104
- return cls.out_channel_axis_mapping.get(node_type)
105
-
106
-
107
- def set_pytorch_info(func):
108
- @wraps(func)
109
- def wrapper(*args, **kwargs):
110
- set_fw_info(PyTorchInfo)
111
- return func(*args, **kwargs)
112
- return wrapper
25
+ from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import power_of_two_quantization, \
26
+ symmetric_quantization, uniform_quantization
27
+ from model_compression_toolkit.core.pytorch.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
28
+
29
+ """
30
+ Map each layer to a list of its' weights attributes that should get quantized.
31
+ If a layer that is not listed here is queried, [None] is returned.
32
+ """
33
+ KERNEL_ATTRIBUTES = DefaultDict({Conv2d: [KERNEL],
34
+ ConvTranspose2d: [KERNEL],
35
+ Linear: [KERNEL]},
36
+ DEFAULT_KERNEL_ATTRIBUTES)
37
+
38
+ """
39
+ Map a layer to its kernel's output and input channels indices.
40
+ Map's values are tuples of (output_channel_index, input_channel_index).
41
+ Default value is returned for layers that are not included.
42
+ """
43
+ DEFAULT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: (0, 1),
44
+ Linear: (0, 1),
45
+ ConvTranspose2d: (1, 0)},
46
+ (None, None))
47
+
48
+ """
49
+ Map a layer to its output channel axis.
50
+ Where axis=-1 is the last axis
51
+ """
52
+ DEFAULT_OUT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: 1,
53
+ Linear: -1,
54
+ ConvTranspose2d: 1},
55
+ 1)
56
+
57
+
58
+ """
59
+ Map from an activation function to its min/max output values (if known).
60
+ The values are used for tensor min/max values initialization.
61
+ """
62
+ ACTIVATION2MINMAX = {} # should be an empty dict in Pytorch
63
+
64
+ """
65
+ Map from an Pytorch module to its min/max output values (if known).
66
+ The values are used for tensor min/max values initialization.
67
+ """
68
+ LAYER2MINMAX = {Softmax: (0, SOFTMAX_THRESHOLD),
69
+ softmax: (0, SOFTMAX_THRESHOLD),
70
+ Sigmoid: (0, 1),
71
+ sigmoid: (0, 1),
72
+ Hardsigmoid: (0, 1),
73
+ hardsigmoid: (0, 1),
74
+ ReLU: (0, None),
75
+ relu: (0, None),
76
+ ReLU6: (0, None),
77
+ relu6: (0, None),
78
+ GELU: (-0.17, None),
79
+ gelu: (-0.17, None),
80
+ SELU: (-1.76, None),
81
+ selu: (-1.76, None),
82
+ }
83
+
84
+ """
85
+ Mapping from a QuantizationMethod to an activation quantizer function.
86
+ """
87
+ ACTIVATION_QUANTIZER_MAPPING = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
88
+ QuantizationMethod.SYMMETRIC: symmetric_quantization,
89
+ QuantizationMethod.UNIFORM: uniform_quantization,
90
+ QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
91
+
92
+
93
+ DEFAULT_PYTORCH_INFO = FrameworkInfo(ACTIVATION_QUANTIZER_MAPPING,
94
+ DEFAULT_CHANNEL_AXIS_DICT,
95
+ ACTIVATION2MINMAX,
96
+ LAYER2MINMAX,
97
+ KERNEL_ATTRIBUTES,
98
+ DEFAULT_OUT_CHANNEL_AXIS_DICT)
@@ -21,18 +21,19 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
21
21
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
22
22
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
23
23
  from model_compression_toolkit.core.pytorch.constants import IN_CHANNELS, OUT_CHANNELS, KERNEL_SIZE, KERNEL, BIAS
24
- from model_compression_toolkit.core.common.framework_info import get_fw_info
24
+ from model_compression_toolkit.core.common import FrameworkInfo
25
25
 
26
26
 
27
27
  class FunctionalConvSubstitution(common.BaseSubstitution):
28
28
  """
29
29
  Substitute functional convolutions with Layers
30
30
  """
31
- def __init__(self):
31
+ def __init__(self, fw_info: FrameworkInfo):
32
32
  """
33
33
  Matches a functional conv node
34
34
  """
35
35
  func_node = NodeOperationMatcher(conv2d) | NodeOperationMatcher(conv_transpose2d)
36
+ self.fw_info = fw_info
36
37
  super().__init__(matcher_instance=func_node)
37
38
 
38
39
  def substitute(self,
@@ -55,7 +56,7 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
55
56
  else:
56
57
  Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.') # pragma: no cover
57
58
 
58
- out_channel_index, in_channel_index = get_fw_info().get_kernel_channels(new_layer)
59
+ out_channel_index, in_channel_index = self.fw_info.kernel_channels_mapping.get(new_layer)
59
60
 
60
61
  # Create new node of layer convolution
61
62
  if 1 not in func_node.weights:
@@ -95,11 +95,11 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
95
95
  else:
96
96
  return graph
97
97
  elif non_linear_node.is_match_type(hardtanh):
98
- kwargs = non_linear_node.op_call_kwargs
99
- if (kwargs[HARDTANH_MIN_VAL] == 0.0) and not \
100
- (np.log2(kwargs[HARDTANH_MAX_VAL]).astype(int) - np.log2(kwargs[HARDTANH_MAX_VAL]) == 0):
101
- scale_factor = kwargs[HARDTANH_MAX_VAL] / self.threshold
102
- non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, kwargs[INPLACE])
98
+ if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
99
+ (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
100
+ np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
101
+ scale_factor = non_linear_node.framework_attr[HARDTANH_MAX_VAL] / self.threshold
102
+ non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, non_linear_node.framework_attr[INPLACE])
103
103
  else:
104
104
  return graph
105
105
  else:
@@ -46,15 +46,17 @@ class ScaleEqualization(BaseScaleEqualization):
46
46
  """
47
47
 
48
48
  def __init__(self,
49
- quant_config: QuantizationConfig):
49
+ quant_config: QuantizationConfig,
50
+ fw_info: FrameworkInfo):
50
51
  """
51
52
  Initialize a ScaleEqualization object.
52
53
  Args:
53
54
  quant_config: Quantization configuration.
55
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
54
56
  groups of layers by how they should be quantized, etc.)
55
57
  """
56
58
 
57
- super().__init__(quant_config=quant_config, matcher_instance=MATCHER,
59
+ super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER,
58
60
  kernel_str=KERNEL, bias_str=BIAS)
59
61
 
60
62
 
@@ -64,13 +66,15 @@ class ScaleEqualizationWithPad(BaseScaleEqualization):
64
66
  """
65
67
 
66
68
  def __init__(self,
67
- quant_config: QuantizationConfig):
69
+ quant_config: QuantizationConfig,
70
+ fw_info: FrameworkInfo):
68
71
  """
69
72
  Initialize a ScaleEqualizationWithPad object.
70
73
  Args:
71
74
  quant_config: Quantization configuration.
75
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
72
76
  groups of layers by how they should be quantized, etc.)
73
77
  """
74
78
 
75
- super().__init__(quant_config=quant_config, matcher_instance=MATCHER_WITH_PAD,
79
+ super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_WITH_PAD,
76
80
  kernel_str=KERNEL, bias_str=BIAS)
@@ -29,7 +29,6 @@ from model_compression_toolkit.core.common import BaseNode, Graph
29
29
  from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher
30
30
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
31
31
  from model_compression_toolkit.core.common.substitutions.shift_negative_activation import apply_shift_negative_correction
32
- from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
33
32
  from model_compression_toolkit.core.pytorch.constants import PAD, VALUE, PADDING, BIAS, USE_BIAS
34
33
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
35
34
 
@@ -215,13 +214,15 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,
215
214
 
216
215
 
217
216
  def pytorch_apply_shift_negative_correction(graph: Graph,
218
- core_config: CoreConfig) -> Graph:
217
+ core_config: CoreConfig,
218
+ fw_info: FrameworkInfo) -> Graph:
219
219
  """
220
220
  Apply shift negative correction (SNC) on a graph built from a Pytorch model.
221
221
 
222
222
  Args:
223
223
  graph: Graph to apply SNC on.
224
224
  core_config: Quantization configuration.
225
+ fw_info: FrameworkInfo object with information about the specific framework's module.
225
226
 
226
227
  Returns:
227
228
  Graph after SNC.
@@ -229,6 +230,7 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
229
230
  snc_node, linear_node, bypass_node, pad_node = shift_negative_activation_node_matchers()
230
231
  return apply_shift_negative_correction(graph,
231
232
  core_config,
233
+ fw_info,
232
234
  snc_node,
233
235
  linear_node,
234
236
  bypass_node,
@@ -240,5 +242,4 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
240
242
  PADDING,
241
243
  BIAS,
242
244
  USE_BIAS,
243
- get_activation_quantization_fn_factory,
244
245
  params_search_quantization_fn=params_search_quantization_fn)
@@ -23,6 +23,7 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
23
23
  from model_compression_toolkit.core.common import Graph
24
24
  from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
25
25
  from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
26
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
26
27
  from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
27
28
  HessianScoresCalculatorPytorch
28
29
  from model_compression_toolkit.logger import Logger
@@ -91,14 +92,22 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
91
92
  for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
92
93
 
93
94
  # Check if the target node's layer type is supported.
94
- if not ipt_node.kernel_attr:
95
+ if not DEFAULT_PYTORCH_INFO.is_kernel_op(ipt_node.type):
95
96
  Logger.critical(f"Hessian information with respect to weights is not supported for "
96
97
  f"{ipt_node.type} layers.") # pragma: no cover
97
98
 
98
- weights_tensor = getattr(getattr(model, ipt_node.name), ipt_node.kernel_attr)
99
+ # Get the weight attributes for the target node type
100
+ weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(ipt_node.type)
101
+
102
+ # Get the weight tensor for the target node
103
+ if len(weights_attributes) != 1: # pragma: no cover
104
+ Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a "
105
+ f"single weight attribute. {len(weights_attributes)} attributes found.")
106
+
107
+ weights_tensor = getattr(getattr(model, ipt_node.name), weights_attributes[0])
99
108
 
100
109
  # Get the output channel index
101
- output_channel_axis = ipt_node.channel_axis.output
110
+ output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(ipt_node.type)
102
111
  shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
103
112
  if self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
104
113
  shape_channel_axis.remove(output_channel_axis)
@@ -20,7 +20,6 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
20
20
  verify_candidates_descending_order, init_activation_quantizers
21
21
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
22
22
  CandidateNodeQuantizationConfig
23
- from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
24
23
  from model_compression_toolkit.logger import Logger
25
24
  from mct_quantizers import QuantizationMethod
26
25
  from mct_quantizers import QuantizationTarget
@@ -68,7 +67,7 @@ class ConfigurableActivationQuantizer(BasePyTorchInferableQuantizer):
68
67
  Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).") # pragma: no cover
69
68
 
70
69
  # Setting layer's activation
71
- self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
70
+ self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
72
71
  self.active_quantization_config_index = max_candidate_idx # initialize with first config as default
73
72
 
74
73
  def set_active_activation_quantizer(self, index: Optional[int]):