mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -16,20 +16,21 @@ import copy
16
16
  import numpy as np
17
17
  from typing import List, Tuple, Any, Callable
18
18
 
19
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
19
20
  from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
20
21
  ActivationQuantizationMode
21
22
  from model_compression_toolkit.logger import Logger
22
- from model_compression_toolkit.core.common import Graph, BaseNode
23
+ from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
23
24
  from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
24
25
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
26
+ from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
27
+ set_quantization_configs_to_node
25
28
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
26
29
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
27
- import compute_activation_qparams
30
+ import get_activations_qparams
28
31
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
29
32
  _mse_error_histogram
30
33
  from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
31
- from model_compression_toolkit.quantization_preparation.load_fqc import set_quantization_configs_to_node, \
32
- fetch_qc_options_for_node
33
34
  from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig
34
35
 
35
36
  """
@@ -45,6 +46,7 @@ If the linear node pads the input tensor with zeros, we modify the padded value
45
46
 
46
47
  def op2d_bias_correction(op2d_node: BaseNode,
47
48
  shift_to_correct: float,
49
+ fw_info: FrameworkInfo,
48
50
  bias_str: str,
49
51
  bias_flag_str: str):
50
52
  """
@@ -55,6 +57,7 @@ def op2d_bias_correction(op2d_node: BaseNode,
55
57
  op2d_node: Node to compute its bias correction term.
56
58
  shift_to_correct: Value that was used to shift the output tensor of
57
59
  the non-linear node.
60
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
58
61
  bias_str: The framework specific attribute name of the bias.
59
62
  bias_flag_str: The framework specific attribute name of the bias flag.
60
63
  """
@@ -66,19 +69,21 @@ def op2d_bias_correction(op2d_node: BaseNode,
66
69
  # Add an attribute quantization configuration to the newly added bias attribute, with disabled quantization
67
70
  for qc in op2d_node.candidates_quantization_cfg:
68
71
  qc.weights_quantization_cfg.set_attr_config(bias_flag_str,
69
- WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
72
+ WeightsAttrQuantizationConfig(QuantizationConfig(),
73
+ AttributeQuantizationConfig(
70
74
  enable_weights_quantization=False)))
71
75
 
72
76
  # Each node adds a different noise due to the shifting. It depends on the
73
77
  # dimensions of the kernel, thus the correction term is a function of
74
78
  # the layer type.
75
- kernel = op2d_node.get_weights_by_keys(op2d_node.kernel_attr)
79
+ kernel = op2d_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(op2d_node.type)[0])
76
80
  if kernel is not None:
81
+ output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(op2d_node.type)
77
82
  axis_not_output_channel = list(range(len(kernel.shape)))
78
- axis_not_output_channel.remove(op2d_node.channel_axis.output)
83
+ axis_not_output_channel.remove(output_channel_index)
79
84
 
80
85
  # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters
81
- if op2d_node.channel_axis.output == op2d_node.channel_axis.input:
86
+ if output_channel_index == input_channel_index:
82
87
  axis_not_output_channel.remove(3) # 3 is the depth multiplier index
83
88
 
84
89
  bias_correction = shift_to_correct * np.sum(kernel, axis=tuple(axis_not_output_channel))
@@ -245,13 +250,13 @@ def shift_negative_function(graph: Graph,
245
250
  core_config: CoreConfig,
246
251
  non_linear_node: BaseNode,
247
252
  op2d_node: BaseNode,
253
+ fw_info: FrameworkInfo,
248
254
  create_add_node: Callable,
249
255
  get_padding_values: Callable,
250
256
  create_pad_node: Callable,
251
257
  padding_str: str,
252
258
  bias_str: str,
253
259
  bias_flag_str: str,
254
- get_activation_quantization_fn_factory: Callable,
255
260
  zero_padding_node: BaseNode = None,
256
261
  bypass_nodes: List = None,
257
262
  params_search_quantization_fn: Callable = None
@@ -271,13 +276,14 @@ def shift_negative_function(graph: Graph,
271
276
  non_linear_node: Non-linear node with negative values to shift.
272
277
  op2d_node: Linear node to correct its bias to overcome the expected error due to
273
278
  the shifting.
279
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
280
+ groups of layers by how they should be quantized, etc.)
274
281
  create_add_node: Function to create an add node.
275
282
  get_padding_values: Function to compute the op2d node's padding values
276
283
  create_pad_node: Function to create a pad node.
277
284
  padding_str: The framework specific attribute name of the padding.
278
285
  bias_str: The framework specific attribute name of the bias.
279
286
  bias_flag_str: The framework specific attribute name of the bias flag.
280
- get_activation_quantization_fn_factory: activation quantization functions factory.
281
287
  zero_padding_node: ZeroPadding2D node that may be in the graph before the linear layer.
282
288
  params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.
283
289
 
@@ -293,12 +299,13 @@ def shift_negative_function(graph: Graph,
293
299
  # all candidates have same activation config, so taking the first candidate for calculations
294
300
  non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg
295
301
 
302
+
296
303
  # get the non-linear activation threshold
297
304
  activation_threshold = non_linear_node_cfg_candidate.activation_quantization_params.get(THRESHOLD)
298
305
 
299
306
  negative_rate = np.abs(min_to_correct) / activation_threshold
300
307
 
301
- enable_sub = negative_rate <= core_config.quantization_config.shift_negative_ratio
308
+ enable_sub = negative_rate <= non_linear_node_cfg_candidate.shift_negative_ratio
302
309
  if min_to_correct >= 0 or not enable_sub:
303
310
  return graph
304
311
 
@@ -316,7 +323,7 @@ def shift_negative_function(graph: Graph,
316
323
  if core_config.quantization_config.shift_negative_params_search:
317
324
 
318
325
  hist_bins, hist_count = graph.get_out_stats_collector(non_linear_node).hc.get_histogram()
319
- hist_count = z_score_filter(core_config.quantization_config.z_threshold,
326
+ hist_count = z_score_filter(non_linear_node_cfg_candidate.z_threshold,
320
327
  hist_bins, hist_count)
321
328
 
322
329
  min_mse, _th, _shift = np.inf, None, None
@@ -327,15 +334,13 @@ def shift_negative_function(graph: Graph,
327
334
  'float32') # Change to type float32 to support tensorflow dtypes
328
335
  for _shift_value in _q_points:
329
336
  _hist_bins = hist_bins.astype(np.float32) + _shift_value
330
- quantizer_factory = get_activation_quantization_fn_factory(
331
- non_linear_node_cfg_candidate.activation_quantization_method)
332
- fw_quant_fn = quantizer_factory(non_linear_node_cfg_candidate.activation_n_bits, qparams)
337
+ fw_quant_fn = non_linear_node_cfg_candidate.activation_quantization_fn(non_linear_node_cfg_candidate.activation_n_bits,qparams)
333
338
  """
334
339
  In SNC, when better shifting values are tested for better choice,
335
340
  the histogram (which is a numpy object) is quantized using the non-linear node activation
336
341
  quantization function (to estimate the expected mse comparing to the original histogram).
337
342
  The quantization function is a framework function, which makes it fail since it
338
- expects a fw tensor. The common part of SNC receives an argument which is a callable
343
+ expects a fw tensor. The common part of SNC receives an argument which is a callable
339
344
  that receives two arguments and returns one: it gets the fw activation quantization function
340
345
  and the bins to quantize. The function (of each fw) responsible for doing (if needed) a preprocessing and postprocessing
341
346
  to the bins which is a numpy object.
@@ -385,6 +390,7 @@ def shift_negative_function(graph: Graph,
385
390
  first_node=non_linear_node)
386
391
  op2d_bias_correction(op2d_node,
387
392
  shift_value,
393
+ fw_info,
388
394
  bias_str,
389
395
  bias_flag_str)
390
396
 
@@ -395,9 +401,12 @@ def shift_negative_function(graph: Graph,
395
401
  graph.set_out_stats_collector_to_node(add_node, add_node_stats_collector)
396
402
  graph.shift_stats_collector(add_node, np.array(shift_value))
397
403
 
398
- set_quantization_configs_to_node(node=add_node,
404
+ set_quantization_configs_to_node(fw_info=fw_info,
405
+ node=add_node,
399
406
  graph=graph,
400
- fqc=graph.fqc)
407
+ quant_config=core_config.quantization_config,
408
+ fqc=graph.fqc,
409
+ mixed_precision_enable=core_config.is_mixed_precision_enabled)
401
410
 
402
411
  update_fused_op_with_add(graph=graph,
403
412
  non_linear_node=non_linear_node,
@@ -419,9 +428,12 @@ def shift_negative_function(graph: Graph,
419
428
  last_node=op2d_node)
420
429
 
421
430
  # Set quantization configuration to node, even though we do not quantize it:
422
- set_quantization_configs_to_node(node=pad_node,
431
+ set_quantization_configs_to_node(fw_info=fw_info,
432
+ node=pad_node,
423
433
  graph=graph,
424
- fqc=graph.fqc)
434
+ quant_config=core_config.quantization_config,
435
+ fqc=graph.fqc,
436
+ mixed_precision_enable=core_config.is_mixed_precision_enabled)
425
437
 
426
438
  for candidate_qc in pad_node.candidates_quantization_cfg:
427
439
  candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -446,7 +458,7 @@ def shift_negative_function(graph: Graph,
446
458
  bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
447
459
  graph.shift_stats_collector(bypass_node, np.array(shift_value))
448
460
 
449
- add_node_qco = fetch_qc_options_for_node(add_node, graph.fqc).quantization_configurations
461
+ add_node_qco = add_node.get_qco(graph.fqc).quantization_configurations
450
462
  add_supported_bitwidths = [c.activation_n_bits for c in add_node_qco]
451
463
  if original_non_linear_activation_nbits not in add_supported_bitwidths:
452
464
  raise ValueError(
@@ -454,16 +466,19 @@ def shift_negative_function(graph: Graph,
454
466
  f"bitwidth is {original_non_linear_activation_nbits}. Consider adapting the TPC so 'Add' will support the "
455
467
  f"same bitwidth as {non_linear_node.type} or disable shift negative correction.")
456
468
 
457
- set_quantization_configs_to_node(add_node, graph, graph.fqc)
458
- # TODO: do we not quantize the weights of this 'add' on purpose?
459
- add_node.quantization_cfg.disable_weights_quantization()
469
+ for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
470
+ for attr in add_node.get_node_weights_attributes():
471
+ # TODO: do we not quantize the weights of this 'add' on purpose?
472
+ candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
473
+
474
+ candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config,
475
+ fw_info,
476
+ add_node_qco[op_qc_idx])
460
477
 
461
- def update(c):
462
- c.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
463
- c.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
464
- SIGNED: False})
478
+ candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
479
+ SIGNED: False})
465
480
 
466
- add_node.quantization_cfg.update_all(update, remove_duplicates=True)
481
+ candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
467
482
 
468
483
  # Add the new padding node to a fused op with the op2d.
469
484
  if pad_node:
@@ -471,14 +486,12 @@ def shift_negative_function(graph: Graph,
471
486
  pad_node=pad_node,
472
487
  op2d_node=op2d_node)
473
488
 
474
- if core_config.quantization_config.shift_negative_threshold_recalculation:
475
- activation_param = compute_activation_qparams(quant_cfg=core_config.quantization_config,
476
- node_activation_quant_cfg=non_linear_node_cfg_candidate,
477
- node_prior_info=non_linear_node.prior_info,
478
- out_stats_container=graph.get_out_stats_collector(
479
- non_linear_node))
489
+ if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
490
+ activation_param = get_activations_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
491
+ nodes_prior_info=non_linear_node.prior_info,
492
+ out_stats_container=graph.get_out_stats_collector(non_linear_node))
480
493
 
481
- assert activation_param.get(SIGNED) is False
494
+ assert activation_param.get(SIGNED) == False
482
495
  for candidate_qc in non_linear_node.candidates_quantization_cfg:
483
496
  candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param)
484
497
 
@@ -560,6 +573,7 @@ def get_next_nodes_to_correct(n: BaseNode,
560
573
 
561
574
  def apply_shift_negative_correction(graph: Graph,
562
575
  core_config: CoreConfig,
576
+ fw_info: FrameworkInfo,
563
577
  snc_node_types: NodeOperationMatcher,
564
578
  linear_node_types: NodeOperationMatcher,
565
579
  bypass_node_types: NodeOperationMatcher,
@@ -571,7 +585,6 @@ def apply_shift_negative_correction(graph: Graph,
571
585
  padding_str: str,
572
586
  bias_str: str,
573
587
  bias_flag_str: str,
574
- get_activation_quantization_fn_factory: Callable,
575
588
  params_search_quantization_fn: Callable=None) -> Graph:
576
589
  """
577
590
  Apply the substitution even if the linear node is not immediately after
@@ -580,6 +593,7 @@ def apply_shift_negative_correction(graph: Graph,
580
593
  Args:
581
594
  graph: Graph to apply the substitution on.
582
595
  core_config: Quantization configuration to build the substitutions list according to.
596
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
583
597
  groups of layers by how they should be quantized, etc.)
584
598
  snc_node_types: Types of activation nodes with negative outputs to consider.
585
599
  linear_node_types: Types of linear nodes to consider.
@@ -593,9 +607,6 @@ def apply_shift_negative_correction(graph: Graph,
593
607
  padding_str: The framework specific attribute name of the padding.
594
608
  bias_str: The framework specific attribute name of the bias.
595
609
  bias_flag_str: The framework specific attribute name of the bias flag.
596
- get_activation_quantization_fn_factory: activation quantization functions factory.
597
- params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.
598
-
599
610
  Returns:
600
611
  Graph after applying shift negative on selected activations.
601
612
  """
@@ -603,8 +614,9 @@ def apply_shift_negative_correction(graph: Graph,
603
614
  nodes = list(graph.nodes())
604
615
  for n in nodes:
605
616
  # Skip substitution if QuantizationMethod is uniform.
606
- if any(aqc.activation_quantization_cfg.activation_quantization_method == QuantizationMethod.UNIFORM
607
- for aqc in n.candidates_quantization_cfg):
617
+ node_qco = n.get_qco(graph.fqc)
618
+ if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
619
+ for op_qc in node_qco.quantization_configurations]):
608
620
  continue
609
621
 
610
622
  if snc_node_types.apply(n):
@@ -620,13 +632,13 @@ def apply_shift_negative_correction(graph: Graph,
620
632
  core_config,
621
633
  n,
622
634
  linear_node,
635
+ fw_info,
623
636
  create_add_node,
624
637
  get_padding_values,
625
638
  create_pad_node,
626
639
  padding_str,
627
640
  bias_str,
628
641
  bias_flag_str,
629
- get_activation_quantization_fn_factory,
630
642
  zero_padding_node=pad_node,
631
643
  bypass_nodes=bypass_nodes,
632
644
  params_search_quantization_fn=params_search_quantization_fn)
@@ -50,7 +50,9 @@ class BaseVirtualActivationWeightsComposition(BaseSubstitution):
50
50
  return graph
51
51
 
52
52
  # Virtual composed activation-weights node
53
- v_node = VirtualActivationWeightsNode(act_node, weights_node)
53
+ v_node = VirtualActivationWeightsNode(act_node,
54
+ weights_node,
55
+ fw_info=graph.fw_info)
54
56
 
55
57
  # Update graph
56
58
  graph.add_node(v_node)
@@ -50,7 +50,7 @@ class BaseWeightsActivationSplit(BaseSubstitution):
50
50
  Graph after applying the substitution.
51
51
  """
52
52
  # The decomposition works on linear nodes, that is, nodes with kernel ops
53
- kernel_attr = node.kernel_attr
53
+ kernel_attr = graph.fw_info.get_kernel_op_attributes(node.type)[0]
54
54
  if kernel_attr is None:
55
55
  Logger.critical(f"Trying to split node weights and activation, but node "
56
56
  f"{node.name} doesn't have a kernel attribute.")
@@ -59,19 +59,22 @@ class NNVisualizer:
59
59
  def __init__(self,
60
60
  graph_float: Graph,
61
61
  graph_quantized: Graph,
62
- fw_impl: FrameworkImplementation):
62
+ fw_impl: FrameworkImplementation,
63
+ fw_info: FrameworkInfo):
63
64
  """
64
65
  Initialize a NNVisualizer object.
65
66
  Args:
66
67
  graph_float: Float version of the graph.
67
68
  graph_quantized: Quantized version of the graph.
68
69
  fw_impl: Framework implementation with framework-specific methods implementation.
70
+ fw_info: Framework info with framework-specific information.
69
71
 
70
72
  """
71
73
 
72
74
  self.graph_float = graph_float
73
75
  self.graph_quantized = graph_quantized
74
76
  self.fw_impl = fw_impl
77
+ self.fw_info = fw_info
75
78
 
76
79
  # Get compare points of two graphs.
77
80
  self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
@@ -89,11 +92,13 @@ class NNVisualizer:
89
92
 
90
93
  self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
91
94
  mode=ModelBuilderMode.QUANTIZED,
92
- append2output=self.compare_points)
95
+ append2output=self.compare_points,
96
+ fw_info=self.fw_info)
93
97
 
94
98
  self.float_model, _ = self.fw_impl.model_builder(self.graph_float,
95
99
  mode=ModelBuilderMode.FLOAT,
96
- append2output=self.compare_points_float)
100
+ append2output=self.compare_points_float,
101
+ fw_info=self.fw_info)
97
102
 
98
103
  def has_compare_points(self) -> bool:
99
104
  """
@@ -89,18 +89,20 @@ class TensorboardWriter(object):
89
89
  Class to log events to display using Tensorboard such as graphs, histograms, images, etc.
90
90
  """
91
91
 
92
- def __init__(self, dir_path: str):
92
+ def __init__(self, dir_path: str, fw_info: FrameworkInfo):
93
93
  """
94
94
  Initialize a TensorboardWriter object.
95
95
 
96
96
  Args:
97
97
  dir_path: Path to save all events to display on Tensorboard.
98
+ fw_info: FrameworkInfo object (needed for computing nodes' weights memory).
98
99
 
99
100
  """
100
101
  self.dir_path = dir_path
101
102
  # we hold EventWriter per tag name, so events can be gathered by tags (like phases during the quantization
102
103
  # process).
103
104
  self.tag_name_to_event_writer = {}
105
+ self.fw_info = fw_info
104
106
 
105
107
  def close(self):
106
108
  """
@@ -207,7 +209,7 @@ class TensorboardWriter(object):
207
209
  attr = dict()
208
210
  if n.final_activation_quantization_cfg is not None:
209
211
  attr.update(n.final_activation_quantization_cfg.__dict__)
210
- elif n.quantization_cfg is not None:
212
+ elif n.candidates_quantization_cfg is not None:
211
213
  attr.update(n.get_unified_activation_candidates_dict())
212
214
  return attr
213
215
 
@@ -229,8 +231,8 @@ class TensorboardWriter(object):
229
231
  attr = dict()
230
232
  if n.final_weights_quantization_cfg is not None:
231
233
  attr.update(n.final_weights_quantization_cfg.__dict__)
232
- elif n.quantization_cfg is not None:
233
- attr.update(n.get_unified_weights_candidates_dict())
234
+ elif n.candidates_quantization_cfg is not None:
235
+ attr.update(n.get_unified_weights_candidates_dict(self.fw_info))
234
236
  return attr
235
237
 
236
238
  def __get_node_attr(n: BaseNode) -> Dict[str, Any]:
@@ -294,7 +296,7 @@ class TensorboardWriter(object):
294
296
 
295
297
  return NodeExecStats(node_name=n.name,
296
298
  memory=[AllocatorMemoryUsed(
297
- total_bytes=int(n.get_memory_bytes())
299
+ total_bytes=int(n.get_memory_bytes(self.fw_info))
298
300
  )])
299
301
 
300
302
  graph_def = GraphDef() # GraphDef to add to Tensorboard
@@ -524,12 +526,14 @@ class TensorboardWriter(object):
524
526
  er.add_event(event)
525
527
  er.flush()
526
528
 
527
-
528
- def init_tensorboard_writer() -> TensorboardWriter:
529
+ def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
529
530
  """
530
531
  Create a TensorBoardWriter object initialized with the logger dir path if it was set,
531
532
  or None otherwise.
532
533
 
534
+ Args:
535
+ fw_info: FrameworkInfo object.
536
+
533
537
  Returns:
534
538
  A TensorBoardWriter object.
535
539
  """
@@ -537,7 +541,7 @@ def init_tensorboard_writer() -> TensorboardWriter:
537
541
  if Logger.LOG_PATH is not None:
538
542
  tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
539
543
  Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
540
- tb_w = TensorboardWriter(tb_log_dir)
544
+ tb_w = TensorboardWriter(tb_log_dir, fw_info)
541
545
  return tb_w
542
546
 
543
547
 
@@ -16,27 +16,28 @@
16
16
 
17
17
  from typing import Callable, Any
18
18
 
19
+ from model_compression_toolkit.core.common import FrameworkInfo
19
20
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21
+ from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
20
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
21
23
  from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
22
24
  from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
23
- from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG, \
24
- QuantizationErrorMethod
25
+ from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
25
26
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
26
- from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_manual_bitwidth_config
27
+ from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \
28
+ set_quantization_configuration_to_graph
27
29
  from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
28
30
  from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
29
31
  linear_collapsing_substitute
30
32
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
31
- from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
32
33
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
33
34
  FrameworkQuantizationCapabilities
34
- from model_compression_toolkit.logger import Logger
35
35
 
36
36
 
37
37
  def graph_preparation_runner(in_model: Any,
38
38
  representative_data_gen: Callable,
39
39
  quantization_config: QuantizationConfig,
40
+ fw_info: FrameworkInfo,
40
41
  fw_impl: FrameworkImplementation,
41
42
  fqc: FrameworkQuantizationCapabilities,
42
43
  bit_width_config: BitWidthConfig = None,
@@ -55,6 +56,8 @@ def graph_preparation_runner(in_model: Any,
55
56
  in_model (Any): Model to quantize.
56
57
  representative_data_gen (Callable): Dataset used for calibration.
57
58
  quantization_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
59
+ fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
60
+ groups of layers by how they should be quantized, etc.).
58
61
  fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
59
62
  fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that models the inference target platform and
60
63
  the attached framework operator's information.
@@ -70,6 +73,7 @@ def graph_preparation_runner(in_model: Any,
70
73
  graph = read_model_to_graph(in_model,
71
74
  representative_data_gen,
72
75
  fqc,
76
+ fw_info,
73
77
  fw_impl)
74
78
 
75
79
  if tb_w is not None:
@@ -79,6 +83,7 @@ def graph_preparation_runner(in_model: Any,
79
83
  fqc,
80
84
  quantization_config,
81
85
  bit_width_config,
86
+ fw_info,
82
87
  tb_w,
83
88
  fw_impl,
84
89
  mixed_precision_enable=mixed_precision_enable,
@@ -91,6 +96,7 @@ def get_finalized_graph(initial_graph: Graph,
91
96
  fqc: FrameworkQuantizationCapabilities,
92
97
  quant_config: QuantizationConfig = DEFAULTCONFIG,
93
98
  bit_width_config: BitWidthConfig = None,
99
+ fw_info: FrameworkInfo = None,
94
100
  tb_w: TensorboardWriter = None,
95
101
  fw_impl: FrameworkImplementation = None,
96
102
  mixed_precision_enable: bool = False,
@@ -105,6 +111,8 @@ def get_finalized_graph(initial_graph: Graph,
105
111
  quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
106
112
  quantized.
107
113
  bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
114
+ fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
115
+ kernel channels indices, groups of layers by how they should be quantized, etc.)
108
116
  tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
109
117
  fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
110
118
  mixed_precision_enable: is mixed precision enabled.
@@ -112,17 +120,11 @@ def get_finalized_graph(initial_graph: Graph,
112
120
 
113
121
  Returns: Graph object that represents the model, after applying all required modifications to it.
114
122
  """
115
- if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
116
- if not running_gptq:
117
- raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
118
- f"optimization due to long execution time that is not suitable for basic PTQ.")
119
- Logger.warning("Using the HMSE error method for weights quantization parameters search. "
120
- "Note: This method may significantly increase runtime during the parameter search process.")
121
123
 
122
124
  ######################################
123
125
  # Graph substitution (prepare graph)
124
126
  ######################################
125
- graph = substitute(initial_graph, fw_impl.get_substitutions_prepare_graph())
127
+ graph = substitute(initial_graph, fw_impl.get_substitutions_prepare_graph(fw_info))
126
128
 
127
129
  if tb_w is not None:
128
130
  tb_w.add_graph(graph, 'after_graph_preparation')
@@ -132,6 +134,7 @@ def get_finalized_graph(initial_graph: Graph,
132
134
  ##########################################
133
135
  for node in graph.nodes:
134
136
  node.prior_info = fw_impl.get_node_prior_info(node=node,
137
+ fw_info=fw_info,
135
138
  graph=graph)
136
139
 
137
140
  ##################################################
@@ -147,22 +150,28 @@ def get_finalized_graph(initial_graph: Graph,
147
150
  if tb_w is not None:
148
151
  tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')
149
152
 
150
- transformed_graph = load_fqc_configuration(transformed_graph, fqc)
151
-
152
- # filter candidates per manual config
153
- if bit_width_config:
154
- set_manual_bitwidth_config(graph, bit_width_config)
153
+ ######################################
154
+ # Add quantization configurations
155
+ ######################################
156
+ transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
157
+ quant_config=quant_config,
158
+ bit_width_config=bit_width_config,
159
+ mixed_precision_enable=mixed_precision_enable,
160
+ running_gptq=running_gptq)
155
161
 
156
- # TODO irena: remove after base config is used
157
- for n in transformed_graph.nodes:
158
- if not mixed_precision_enable:
159
- n.quantization_cfg.candidates_quantization_cfg = [n.quantization_cfg.base_quantization_cfg]
162
+ ######################################
163
+ # Layer fusing
164
+ ######################################
165
+ fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(transformed_graph)
166
+ transformed_graph.fusing_info = fusing_info
167
+ transformed_graph.disable_fused_nodes_activation_quantization()
160
168
 
161
169
  ######################################
162
170
  # Channel equalization
163
171
  ######################################
164
172
  transformed_graph = substitute(transformed_graph,
165
- fw_impl.get_substitutions_channel_equalization(quant_config))
173
+ fw_impl.get_substitutions_channel_equalization(quant_config,
174
+ fw_info))
166
175
 
167
176
  if tb_w is not None:
168
177
  tb_w.add_graph(transformed_graph, 'after_graph_marking')
@@ -181,6 +190,7 @@ def get_finalized_graph(initial_graph: Graph,
181
190
  def read_model_to_graph(in_model: Any,
182
191
  representative_data_gen: Callable,
183
192
  fqc: FrameworkQuantizationCapabilities,
193
+ fw_info: FrameworkInfo = None,
184
194
  fw_impl: FrameworkImplementation = None) -> Graph:
185
195
 
186
196
  """
@@ -191,6 +201,8 @@ def read_model_to_graph(in_model: Any,
191
201
  representative_data_gen: Dataset used for calibration.
192
202
  fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
193
203
  the attached framework operator's information.
204
+ fw_info: Information needed for quantization about the specific framework (e.g.,
205
+ kernel channels indices, groups of layers by how they should be quantized, etc.)
194
206
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
195
207
 
196
208
  Returns:
@@ -198,5 +210,6 @@ def read_model_to_graph(in_model: Any,
198
210
  """
199
211
  graph = fw_impl.model_reader(in_model,
200
212
  representative_data_gen)
213
+ graph.set_fw_info(fw_info)
201
214
  graph.set_fqc(fqc)
202
215
  return graph
@@ -17,6 +17,7 @@ from typing import List
17
17
  from model_compression_toolkit.core import FrameworkInfo
18
18
  from model_compression_toolkit.core.common import BaseNode
19
19
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
20
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
20
21
  from model_compression_toolkit.core import common
21
22
  from tensorflow.python.util.object_identity import Reference as TFReference
22
23
 
@@ -28,17 +29,20 @@ class FloatKerasModelBuilder(KerasModelBuilder):
28
29
  def __init__(self,
29
30
  graph: common.Graph,
30
31
  append2output=None,
32
+ fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
31
33
  return_float_outputs: bool = False):
32
34
  """
33
35
 
34
36
  Args:
35
37
  graph: Graph to build the model from.
36
38
  append2output: Nodes to append to model's output.
39
+ fw_info: Information about the specific framework of the model that is built.
37
40
  return_float_outputs: Whether the model returns float tensors or not.
38
41
  """
39
42
 
40
43
  super().__init__(graph,
41
44
  append2output,
45
+ fw_info,
42
46
  return_float_outputs)
43
47
 
44
48
  def _quantize_node_activations(self,