mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only, and reflects the changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -37,6 +37,7 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
37
37
  tb_w: TensorboardWriter,
38
38
  tg: Graph,
39
39
  tg_bias: Graph,
40
+ fw_info: FrameworkInfo,
40
41
  fw_impl: FrameworkImplementation,
41
42
  hessian_info_service: HessianInfoService = None) -> Graph:
42
43
  """
@@ -51,6 +52,7 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
51
52
  tb_w: TensorBoardWriter object to log events.
52
53
  tg: Float Reference Graph.
53
54
  tg_bias: Graph of quantized model.
55
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
54
56
  fw_impl: Framework implementation per framework
55
57
  hessian_info_service: HessianInfoService to fetch information based on the hessian approximation for the float model.
56
58
  Returns:
@@ -62,6 +64,7 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
62
64
  gptq_config,
63
65
  representative_data_gen,
64
66
  fw_impl,
67
+ fw_info,
65
68
  hessian_info_service=hessian_info_service)
66
69
 
67
70
  if tb_w is not None:
@@ -74,6 +77,7 @@ def gptq_runner(tg: Graph,
74
77
  gptq_config: GradientPTQConfig,
75
78
  representative_data_gen: Callable,
76
79
  gptq_representative_data_gen: Callable,
80
+ fw_info: FrameworkInfo,
77
81
  fw_impl: FrameworkImplementation,
78
82
  tb_w: TensorboardWriter,
79
83
  hessian_info_service: HessianInfoService = None) -> Graph:
@@ -87,6 +91,7 @@ def gptq_runner(tg: Graph,
87
91
  gptq_config: GradientPTQConfig with parameters about the tuning process.
88
92
  representative_data_gen: Dataset used for calibration.
89
93
  gptq_representative_data_gen: Dataset used for GPTQ training
94
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
90
95
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
91
96
  tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
92
97
  hessian_info_service: HessianScoresService to fetch approximations of the hessian scores for the float model.
@@ -99,7 +104,7 @@ def gptq_runner(tg: Graph,
99
104
  #############################################
100
105
  # Apply Statistics Correction
101
106
  #############################################
102
- tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)
107
+ tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
103
108
 
104
109
  if tb_w is not None:
105
110
  tb_w.add_graph(tg_bias, 'after_bias_correction')
@@ -112,6 +117,7 @@ def gptq_runner(tg: Graph,
112
117
  tb_w,
113
118
  tg,
114
119
  tg_bias,
120
+ fw_info,
115
121
  fw_impl,
116
122
  hessian_info_service=hessian_info_service)
117
123
 
@@ -17,7 +17,6 @@ from typing import Callable, Tuple, Union
17
17
 
18
18
  from model_compression_toolkit import get_target_platform_capabilities
19
19
  from model_compression_toolkit.constants import TENSORFLOW
20
- from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
21
20
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
22
21
  from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
23
22
  from model_compression_toolkit.verify_packages import FOUND_TF
@@ -25,8 +24,10 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
25
24
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
26
25
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
27
26
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
27
+ from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
28
28
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
29
29
  from model_compression_toolkit.logger import Logger
30
+ from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
30
31
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
31
32
 
32
33
  if FOUND_TF:
@@ -34,12 +35,11 @@ if FOUND_TF:
34
35
  AttachTpcToKeras
35
36
  from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
36
37
  from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
37
- from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
38
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
38
39
  from tensorflow.keras.models import Model
39
40
 
40
41
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
41
42
 
42
- @set_keras_info
43
43
  def keras_pruning_experimental(model: Model,
44
44
  target_resource_utilization: ResourceUtilization,
45
45
  representative_data_gen: Callable,
@@ -116,25 +116,30 @@ if FOUND_TF:
116
116
 
117
117
  target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
118
118
  # Attach tpc model to framework
119
- framework_platform_capabilities = AttachTpcToKeras().attach(target_platform_capabilities)
119
+ attach2keras = AttachTpcToKeras()
120
+ target_platform_capabilities = attach2keras.attach(target_platform_capabilities)
120
121
 
121
122
  # Convert the original Keras model to an internal graph representation.
122
123
  float_graph = read_model_to_graph(model,
123
124
  representative_data_gen,
124
- framework_platform_capabilities,
125
+ target_platform_capabilities,
126
+ DEFAULT_KERAS_INFO,
125
127
  fw_impl)
126
128
 
127
129
  # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
128
130
  # as it prepares the graph for the pruning process.
129
- float_graph_with_compression_config = load_fqc_configuration(float_graph, framework_platform_capabilities)
131
+ float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
132
+ quant_config=DEFAULTCONFIG,
133
+ mixed_precision_enable=False)
130
134
 
131
135
  # Create a Pruner object with the graph and configuration.
132
136
  pruner = Pruner(float_graph_with_compression_config,
137
+ DEFAULT_KERAS_INFO,
133
138
  fw_impl,
134
139
  target_resource_utilization,
135
140
  representative_data_gen,
136
141
  pruning_config,
137
- framework_platform_capabilities)
142
+ target_platform_capabilities)
138
143
 
139
144
  # Apply the pruning process.
140
145
  pruned_graph = pruner.prune_graph()
@@ -23,9 +23,10 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
23
23
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
24
24
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
25
25
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
26
- from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
26
+ from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
27
27
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
28
28
  from model_compression_toolkit.logger import Logger
29
+ from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
29
30
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
30
31
 
31
32
 
@@ -35,7 +36,7 @@ if FOUND_TORCH:
35
36
  from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
36
37
  from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
37
38
  PruningPytorchImplementation
38
- from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
39
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
39
40
  from torch.nn import Module
40
41
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
41
42
  AttachTpcToPytorch
@@ -43,7 +44,6 @@ if FOUND_TORCH:
43
44
  # Set the default Target Platform Capabilities (TPC) for PyTorch.
44
45
  DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
45
46
 
46
- @set_pytorch_info
47
47
  def pytorch_pruning_experimental(model: Module,
48
48
  target_resource_utilization: ResourceUtilization,
49
49
  representative_data_gen: Callable,
@@ -129,14 +129,18 @@ if FOUND_TORCH:
129
129
  float_graph = read_model_to_graph(model,
130
130
  representative_data_gen,
131
131
  framework_platform_capabilities,
132
+ DEFAULT_PYTORCH_INFO,
132
133
  fw_impl)
133
134
 
134
135
  # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
135
136
  # as it prepares the graph for the pruning process.
136
- float_graph_with_compression_config = load_fqc_configuration(float_graph, framework_platform_capabilities)
137
+ float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
138
+ quant_config=DEFAULTCONFIG,
139
+ mixed_precision_enable=False)
137
140
 
138
141
  # Create a Pruner object with the graph and configuration.
139
142
  pruner = Pruner(float_graph_with_compression_config,
143
+ DEFAULT_PYTORCH_INFO,
140
144
  fw_impl,
141
145
  target_resource_utilization,
142
146
  representative_data_gen,
@@ -36,8 +36,9 @@ from model_compression_toolkit.metadata import create_model_metadata
36
36
  if FOUND_TF:
37
37
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
38
38
  AttachTpcToKeras
39
- from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
39
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
40
40
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
41
+ from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
41
42
  from tensorflow.keras.models import Model
42
43
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
43
44
  from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
@@ -48,7 +49,6 @@ if FOUND_TF:
48
49
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
49
50
 
50
51
 
51
- @set_keras_info
52
52
  def keras_post_training_quantization(in_model: Model,
53
53
  representative_data_gen: Callable,
54
54
  target_resource_utilization: ResourceUtilization = None,
@@ -121,20 +121,25 @@ if FOUND_TF:
121
121
 
122
122
  >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model, repr_datagen, ru, core_config=config)
123
123
 
124
- For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
124
+ For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
125
125
 
126
126
  """
127
127
 
128
128
  if core_config.debug_config.bypass:
129
129
  return in_model, None
130
130
 
131
+ fw_info = DEFAULT_KERAS_INFO
132
+
133
+ KerasModelValidation(model=in_model,
134
+ fw_info=fw_info).validate()
135
+
131
136
  if core_config.is_mixed_precision_enabled:
132
137
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
133
138
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
134
139
  "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
135
140
  "API, or pass a valid mixed precision configuration.") # pragma: no cover
136
141
 
137
- tb_w = init_tensorboard_writer()
142
+ tb_w = init_tensorboard_writer(fw_info)
138
143
 
139
144
  fw_impl = KerasImplementation()
140
145
 
@@ -148,6 +153,7 @@ if FOUND_TF:
148
153
  tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
149
154
  representative_data_gen=representative_data_gen,
150
155
  core_config=core_config,
156
+ fw_info=fw_info,
151
157
  fw_impl=fw_impl,
152
158
  fqc=framework_platform_capabilities,
153
159
  target_resource_utilization=target_resource_utilization,
@@ -163,6 +169,7 @@ if FOUND_TF:
163
169
  graph_with_stats_correction = ptq_runner(tg,
164
170
  representative_data_gen,
165
171
  core_config,
172
+ fw_info,
166
173
  fw_impl,
167
174
  tb_w)
168
175
 
@@ -172,7 +179,8 @@ if FOUND_TF:
172
179
  tb_w,
173
180
  similarity_baseline_graph,
174
181
  quantized_graph,
175
- fw_impl)
182
+ fw_impl,
183
+ fw_info)
176
184
 
177
185
  exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
178
186
  if framework_platform_capabilities.tpc.add_metadata:
@@ -34,7 +34,7 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
34
34
  from model_compression_toolkit.metadata import create_model_metadata
35
35
 
36
36
  if FOUND_TORCH:
37
- from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
37
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
38
38
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
39
39
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
40
40
  from torch.nn import Module
@@ -46,7 +46,6 @@ if FOUND_TORCH:
46
46
 
47
47
  DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
48
48
 
49
- @set_pytorch_info
50
49
  def pytorch_post_training_quantization(in_module: Module,
51
50
  representative_data_gen: Callable,
52
51
  target_resource_utilization: ResourceUtilization = None,
@@ -103,6 +102,8 @@ if FOUND_TORCH:
103
102
  if core_config.debug_config.bypass:
104
103
  return in_module, None
105
104
 
105
+ fw_info = DEFAULT_PYTORCH_INFO
106
+
106
107
  if core_config.is_mixed_precision_enabled:
107
108
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
108
109
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
@@ -110,7 +111,7 @@ if FOUND_TORCH:
110
111
  "pytorch_post_training_quantization API, or pass a valid mixed precision "
111
112
  "configuration.") # pragma: no cover
112
113
 
113
- tb_w = init_tensorboard_writer()
114
+ tb_w = init_tensorboard_writer(fw_info)
114
115
 
115
116
  fw_impl = PytorchImplementation()
116
117
 
@@ -124,6 +125,7 @@ if FOUND_TORCH:
124
125
  tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
125
126
  representative_data_gen=representative_data_gen,
126
127
  core_config=core_config,
128
+ fw_info=fw_info,
127
129
  fw_impl=fw_impl,
128
130
  fqc=framework_platform_capabilities,
129
131
  target_resource_utilization=target_resource_utilization,
@@ -139,6 +141,7 @@ if FOUND_TORCH:
139
141
  graph_with_stats_correction = ptq_runner(tg,
140
142
  representative_data_gen,
141
143
  core_config,
144
+ fw_info,
142
145
  fw_impl,
143
146
  tb_w)
144
147
 
@@ -148,7 +151,8 @@ if FOUND_TORCH:
148
151
  tb_w,
149
152
  similarity_baseline_graph,
150
153
  quantized_graph,
151
- fw_impl)
154
+ fw_impl,
155
+ fw_info)
152
156
 
153
157
  exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
154
158
  if framework_platform_capabilities.tpc.add_metadata:
@@ -16,6 +16,7 @@
16
16
 
17
17
  from typing import Callable
18
18
 
19
+ from model_compression_toolkit.core.common import FrameworkInfo
19
20
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
21
  from model_compression_toolkit.core.common.graph.base_graph import Graph
21
22
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -27,6 +28,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
27
28
  def ptq_runner(tg: Graph,
28
29
  representative_data_gen: Callable,
29
30
  core_config: CoreConfig,
31
+ fw_info: FrameworkInfo,
30
32
  fw_impl: FrameworkImplementation,
31
33
  tb_w: TensorboardWriter) -> Graph:
32
34
  """
@@ -36,6 +38,7 @@ def ptq_runner(tg: Graph,
36
38
  tg: Graph to apply PTQ and to quantize.
37
39
  representative_data_gen (Callable): Dataset used for calibration.
38
40
  core_config: CoreConfig containing parameters of how the model should be quantized.
41
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
39
42
  groups of layers by how they should be quantized, etc.)
40
43
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
41
44
  tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
@@ -47,5 +50,5 @@ def ptq_runner(tg: Graph,
47
50
  #############################################
48
51
  # Statistics Correction
49
52
  #############################################
50
- tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)
53
+ tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
51
54
  return tg
@@ -19,17 +19,21 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
19
19
  from model_compression_toolkit.trainable_infrastructure import TrainingMethod
20
20
 
21
21
 
22
- def is_qat_applicable(node: common.BaseNode) -> bool:
22
+ def is_qat_applicable(node: common.BaseNode,
23
+ fw_info: FrameworkInfo) -> bool:
23
24
  """
24
25
  A function for deciding if a layer should be fine-tuned during QAT
25
26
 
26
27
  Args:
27
28
  node (BaseNode): Node for quantization decision
29
+ fw_info (FrameworkInfo): Pytorch quantization information
28
30
 
29
31
  Returns:
30
32
  A boolean whether the layer is to be wrapped with a QuantizeWrapper
31
33
  """
32
- return (node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)) \
34
+
35
+ kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
36
+ return (kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)) \
33
37
  or node.is_activation_quantization_enabled()
34
38
 
35
39
 
@@ -37,9 +37,10 @@ if FOUND_TF:
37
37
  from tensorflow.keras.models import Model
38
38
 
39
39
  from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
40
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
40
41
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
42
+ from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
41
43
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
42
- from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
43
44
 
44
45
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
45
46
 
@@ -51,6 +52,7 @@ if FOUND_TF:
51
52
  from model_compression_toolkit.constants import TENSORFLOW
52
53
  from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
53
54
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
55
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
54
56
  from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
55
57
  get_activation_quantizer_holder
56
58
  from model_compression_toolkit.qat.common.qat_config import QATConfig
@@ -71,11 +73,11 @@ if FOUND_TF:
71
73
  Returns: Wrapped layer
72
74
 
73
75
  """
74
- if is_qat_applicable(n):
76
+ if is_qat_applicable(n, DEFAULT_KERAS_INFO):
75
77
  # If we are here, then the node has a kernel attribute to quantize and training during QAT
76
78
  weights_quantizers, _ = quantization_builder(n,
77
79
  qat_config,
78
- n.kernel_attr)
80
+ DEFAULT_KERAS_INFO.get_kernel_op_attributes(n.type)[0])
79
81
  if len(weights_quantizers) > 0:
80
82
  layer.trainable = True
81
83
  return KerasTrainableQuantizationWrapper(layer, weights_quantizers)
@@ -85,7 +87,6 @@ if FOUND_TF:
85
87
  return layer
86
88
 
87
89
 
88
- @set_keras_info
89
90
  def keras_quantization_aware_training_init_experimental(in_model: Model,
90
91
  representative_data_gen: Callable,
91
92
  target_resource_utilization: ResourceUtilization = None,
@@ -166,7 +167,7 @@ if FOUND_TF:
166
167
 
167
168
  >>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)
168
169
 
169
- For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
170
+ For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
170
171
 
171
172
  """
172
173
 
@@ -174,13 +175,16 @@ if FOUND_TF:
174
175
  f"If you encounter an issue, please open an issue in our GitHub "
175
176
  f"project https://github.com/sony/model_optimization")
176
177
 
178
+ KerasModelValidation(model=in_model,
179
+ fw_info=DEFAULT_KERAS_INFO).validate()
180
+
177
181
  if core_config.is_mixed_precision_enabled:
178
182
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
179
183
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
180
184
  "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
181
185
  "or pass a valid mixed precision configuration.")
182
186
 
183
- tb_w = init_tensorboard_writer()
187
+ tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
184
188
 
185
189
  fw_impl = KerasImplementation()
186
190
 
@@ -194,15 +198,17 @@ if FOUND_TF:
194
198
  tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
195
199
  representative_data_gen=representative_data_gen,
196
200
  core_config=core_config,
201
+ fw_info=DEFAULT_KERAS_INFO,
197
202
  fw_impl=fw_impl,
198
203
  fqc=target_platform_capabilities,
199
204
  target_resource_utilization=target_resource_utilization,
200
205
  tb_w=tb_w)
201
206
 
202
- tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)
207
+ tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)
203
208
 
204
209
  _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
205
210
  qat_model, user_info = KerasModelBuilder(graph=tg,
211
+ fw_info=DEFAULT_KERAS_INFO,
206
212
  wrapper=_qat_wrapper,
207
213
  get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
208
214
  qat_config=qat_config)).build_model()
@@ -36,7 +36,7 @@ if FOUND_TORCH:
36
36
  import torch.nn as nn
37
37
  from torch.nn import Module
38
38
  from mct_quantizers import PytorchActivationQuantizationHolder
39
- from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
39
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
40
40
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
41
41
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
42
42
  from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
@@ -62,10 +62,10 @@ if FOUND_TORCH:
62
62
  Returns: Wrapped layer
63
63
 
64
64
  """
65
- if is_qat_applicable(n):
65
+ if is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
66
66
  # If we are here, then the node has a kernel attribute to quantize and training during QAT
67
67
  weights_quantizers, _ = quantization_builder(n, qat_config,
68
- n.kernel_attr)
68
+ DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(n.type)[0])
69
69
  if len(weights_quantizers) > 0:
70
70
  return PytorchQuantizationWrapper(module, weights_quantizers)
71
71
 
@@ -74,7 +74,6 @@ if FOUND_TORCH:
74
74
  return module
75
75
 
76
76
 
77
- @set_pytorch_info
78
77
  def pytorch_quantization_aware_training_init_experimental(in_model: Module,
79
78
  representative_data_gen: Callable,
80
79
  target_resource_utilization: ResourceUtilization = None,
@@ -136,7 +135,7 @@ if FOUND_TORCH:
136
135
 
137
136
  >>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, core_config=config)
138
137
 
139
- For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
138
+ For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
140
139
 
141
140
  """
142
141
  Logger.warning(
@@ -150,7 +149,7 @@ if FOUND_TORCH:
150
149
  "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
151
150
  "or pass a valid mixed precision configuration.")
152
151
 
153
- tb_w = init_tensorboard_writer()
152
+ tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
154
153
  fw_impl = PytorchImplementation()
155
154
 
156
155
  target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
@@ -163,16 +162,18 @@ if FOUND_TORCH:
163
162
  tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
164
163
  representative_data_gen=representative_data_gen,
165
164
  core_config=core_config,
165
+ fw_info=DEFAULT_PYTORCH_INFO,
166
166
  fw_impl=fw_impl,
167
167
  fqc=framework_platform_capabilities,
168
168
  target_resource_utilization=target_resource_utilization,
169
169
  tb_w=tb_w)
170
170
 
171
- tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)
171
+ tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
172
172
 
173
173
  _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
174
174
 
175
175
  qat_model, user_info = PyTorchModelBuilder(graph=tg,
176
+ fw_info=DEFAULT_PYTORCH_INFO,
176
177
  wrapper=_qat_wrapper,
177
178
  get_activation_quantizer_holder_fn=partial(
178
179
  get_activation_quantizer_holder,
@@ -180,6 +181,9 @@ if FOUND_TORCH:
180
181
 
181
182
  user_info.mixed_precision_cfg = bit_widths_config
182
183
 
184
+ # Remove fw_info from graph to enable saving the pytorch model (fw_info can not be pickled)
185
+ delattr(qat_model.graph, 'fw_info')
186
+
183
187
  return qat_model, user_info
184
188
 
185
189
 
@@ -29,7 +29,7 @@ QNNPACK_TP_MODEL = 'qnnpack'
29
29
  # TP Attributes
30
30
  KERNEL_ATTR = "kernel_attr"
31
31
  BIAS_ATTR = "bias_attr"
32
- POSITIONAL_ATTR = "pos_attr"
32
+ POS_ATTR = "pos_attr"
33
33
 
34
34
  # TODO: this is duplicated from the core frameworks constants files, because the original consts can't be used here
35
35
  # duo to circular dependency. It might be best to extract the constants from the core file and put them here (in a
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
32
32
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
33
33
  AttachTpcToFramework
34
34
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
35
- from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices
35
+ from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB
36
36
 
37
37
 
38
38
  class AttachTpcToPytorch(AttachTpcToFramework):
@@ -98,7 +98,7 @@ class AttachTpcToPytorch(AttachTpcToFramework):
98
98
  OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
99
99
  Eq('p', 2) | Eq('p', None))],
100
100
  OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
101
- OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
101
+ OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB],
102
102
  OperatorSetNames.EXP: [torch.exp],
103
103
  OperatorSetNames.SIN: [torch.sin],
104
104
  OperatorSetNames.COS: [torch.cos],
@@ -48,6 +48,7 @@ def get_trainable_quantizer_weights_config(
48
48
  final_attr_cfg.enable_weights_quantization,
49
49
  final_attr_cfg.weights_channels_axis[0], # Output channel axis
50
50
  final_attr_cfg.weights_per_channel_threshold,
51
+ final_node_cfg.min_threshold,
51
52
  weights_quantization_candidates)
52
53
 
53
54
 
@@ -75,6 +76,7 @@ def get_trainable_quantizer_activation_config(
75
76
  final_cfg.activation_n_bits,
76
77
  final_cfg.activation_quantization_params,
77
78
  final_cfg.enable_activation_quantization,
79
+ final_cfg.min_threshold,
78
80
  activation_quantization_candidates)
79
81
 
80
82
 
@@ -44,6 +44,7 @@ class TrainableQuantizerActivationConfig:
44
44
  activation_n_bits: int,
45
45
  activation_quantization_params: Dict,
46
46
  enable_activation_quantization: bool,
47
+ min_threshold: float,
47
48
  activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
48
49
  ):
49
50
  """
@@ -54,11 +55,13 @@ class TrainableQuantizerActivationConfig:
54
55
  activation_n_bits (int): Number of bits to quantize the activations.
55
56
  activation_quantization_params (Dict): Dictionary that contains activation quantization params.
56
57
  enable_activation_quantization (bool): Whether to quantize the layer's activations or not.
58
+ min_threshold (float): Minimum threshold to use during thresholds selection.
57
59
  """
58
60
  self.activation_quantization_method = activation_quantization_method
59
61
  self.activation_n_bits = activation_n_bits
60
62
  self.activation_quantization_params = activation_quantization_params
61
63
  self.enable_activation_quantization = enable_activation_quantization
64
+ self.min_threshold = min_threshold
62
65
  self.activation_bits_candidates = activation_quantization_candidates
63
66
 
64
67
 
@@ -70,6 +73,7 @@ class TrainableQuantizerWeightsConfig:
70
73
  enable_weights_quantization: bool,
71
74
  weights_channels_axis: int,
72
75
  weights_per_channel_threshold: bool,
76
+ min_threshold: float,
73
77
  weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
74
78
  ):
75
79
  """
@@ -82,6 +86,7 @@ class TrainableQuantizerWeightsConfig:
82
86
  enable_weights_quantization (bool): Whether to quantize the layer's weights or not.
83
87
  weights_channels_axis (int): Axis to quantize a node's kernel when quantizing per-channel.
84
88
  weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
89
+ min_threshold (float): Minimum threshold to use during thresholds selection.
85
90
  """
86
91
  self.weights_quantization_method = weights_quantization_method
87
92
  self.weights_n_bits = weights_n_bits
@@ -89,4 +94,5 @@ class TrainableQuantizerWeightsConfig:
89
94
  self.enable_weights_quantization = enable_weights_quantization
90
95
  self.weights_channels_axis = weights_channels_axis
91
96
  self.weights_per_channel_threshold = weights_per_channel_threshold
97
+ self.min_threshold = min_threshold
92
98
  self.weights_bits_candidates = weights_quantization_candidates
@@ -77,11 +77,13 @@ def config_deserialization(in_config: dict) -> Union[TrainableQuantizerWeightsCo
77
77
  weights_quantization_params=weights_quantization_params,
78
78
  enable_weights_quantization=in_config[C.ENABLE_WEIGHTS_QUANTIZATION],
79
79
  weights_channels_axis=in_config[C.WEIGHTS_CHANNELS_AXIS],
80
- weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD])
80
+ weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD],
81
+ min_threshold=in_config[C.MIN_THRESHOLD])
81
82
  elif in_config[C.IS_ACTIVATIONS]:
82
83
  return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod(in_config[C.ACTIVATION_QUANTIZATION_METHOD]),
83
84
  activation_n_bits=in_config[C.ACTIVATION_N_BITS],
84
85
  activation_quantization_params=in_config[C.ACTIVATION_QUANTIZATION_PARAMS],
85
- enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION])
86
+ enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION],
87
+ min_threshold=in_config[C.MIN_THRESHOLD])
86
88
  else:
87
89
  raise NotImplemented # pragma: no cover
@@ -16,4 +16,5 @@
16
16
  from model_compression_toolkit.xquant.common.xquant_config import XQuantConfig
17
17
  from model_compression_toolkit.xquant.keras.facade_xquant_report import xquant_report_keras_experimental
18
18
  from model_compression_toolkit.xquant.pytorch.facade_xquant_report import xquant_report_pytorch_experimental
19
+ from model_compression_toolkit.xquant.pytorch.facade_xquant_report import xquant_report_troubleshoot_pytorch_experimental
19
20