mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from tqdm import tqdm
16
+ from typing import Callable, Any, Dict
17
+
18
+ from model_compression_toolkit.core.common.model_collector import ModelCollector
19
+ from model_compression_toolkit.xquant import XQuantConfig
20
+ from model_compression_toolkit.xquant.common.constants import OUTPUT_SIMILARITY_METRICS_REPR, OUTPUT_SIMILARITY_METRICS_VAL, INTERMEDIATE_SIMILARITY_METRICS_REPR, \
21
+ INTERMEDIATE_SIMILARITY_METRICS_VAL
22
+ from model_compression_toolkit.xquant.pytorch.framework_report_utils import FrameworkReportUtils
23
+
24
+ from model_compression_toolkit.xquant.pytorch.core_detect_degrade_layer import core_detect_degrade_layer
25
+ from model_compression_toolkit.xquant.pytorch.core_judge_troubleshoot import core_judge_troubleshoot
26
+
27
def core_report_generator(float_model: Any,
                          quantized_model: Any,
                          repr_dataset: Callable,
                          validation_dataset: Callable,
                          fw_report_utils: FrameworkReportUtils,
                          xquant_config: XQuantConfig) -> Dict[str, Any]:
    """
    Build the XQuant report comparing a float model with its quantized counterpart.

    The report is produced in four stages:
      1. The float model is folded into a graph, histograms are collected on it over the
         representative dataset, and the histograms are written to TensorBoard.
      2. Output and intermediate similarity metrics between the two models are computed on both
         the representative and the validation datasets.
      3. The quantized model graph (annotated with the metrics) and a text summary are written to
         TensorBoard.
      4. The metrics are dumped to a JSON file under ``xquant_config.report_dir``.

    Args:
        float_model (Any): The original floating-point model.
        quantized_model (Any): The model after quantization.
        repr_dataset (Callable): Representative dataset used for histogram collection and similarity metrics.
        validation_dataset (Callable): Validation dataset used for similarity metrics computation.
        fw_report_utils (FrameworkReportUtils): Framework-specific utilities used to build the report.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        Dict[str, Any]: The similarity metrics, keyed by OUTPUT_SIMILARITY_METRICS_REPR,
        OUTPUT_SIMILARITY_METRICS_VAL, INTERMEDIATE_SIMILARITY_METRICS_REPR and
        INTERMEDIATE_SIMILARITY_METRICS_VAL.
    """
    metadata = fw_report_utils.get_metadata_fn(quantized_model)

    # Stage 1: histograms on the folded float graph.
    folded_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
    collector = ModelCollector(folded_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
    for batch in tqdm(repr_dataset(), desc="Collecting Histograms"):
        collector.infer(batch)
    fw_report_utils.tb_utils.add_histograms_to_tensorboard(graph=folded_graph)

    # Stage 2: similarity metrics; both calls share the model pair and the custom metrics.
    compute_similarity = fw_report_utils.similarity_calculator.compute_similarity_metrics
    shared_kwargs = dict(float_model=float_model,
                         quantized_model=quantized_model,
                         custom_similarity_metrics=xquant_config.custom_similarity_metrics)
    repr_results = compute_similarity(dataset=repr_dataset, **shared_kwargs)
    val_results = compute_similarity(dataset=validation_dataset, is_validation=True, **shared_kwargs)

    # Index 0 holds the output metrics, index 1 the intermediate (per-layer) metrics.
    similarity_metrics = {
        OUTPUT_SIMILARITY_METRICS_REPR: repr_results[0],
        OUTPUT_SIMILARITY_METRICS_VAL: val_results[0],
        INTERMEDIATE_SIMILARITY_METRICS_REPR: repr_results[1],
        INTERMEDIATE_SIMILARITY_METRICS_VAL: val_results[1],
    }

    # Stage 3: TensorBoard graph and text summary.
    tb_utils = fw_report_utils.tb_utils
    tb_utils.add_graph_to_tensorboard(quantized_model, similarity_metrics, repr_dataset, metadata)
    tb_utils.add_text_information(similarity_metrics, metadata)

    # Stage 4: persist the metrics as JSON.
    fw_report_utils.dump_report_to_json(report_dir=xquant_config.report_dir,
                                        collected_data=similarity_metrics)

    return similarity_metrics
+
95
def core_report_generator_troubleshoot(float_model: Any,
                                       quantized_model: Any,
                                       repr_dataset: Callable,
                                       validation_dataset: Callable,
                                       fw_report_utils: FrameworkReportUtils,
                                       xquant_config: XQuantConfig) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Generate the XQuant report and, additionally, a troubleshooting report for degraded layers.

    The first part is the same as the standard report. Histograms are collected on the folded float
    model and written to TensorBoard. Similarity metrics between the float and quantized models are
    computed on the representative and validation datasets. The quantized model graph and a text
    summary are written to TensorBoard, and the metrics are dumped to a JSON file.

    Next, degraded layers are detected from the similarity results. When the ratio of degraded layers
    to all intermediate layers is below ``xquant_config.threshold_degrade_layer_ratio``, the causes
    of degradation are analyzed using the float model, the quantized model and the folded float graph.
    Otherwise the troubleshooting result is an empty dict. The troubleshooting result is dumped to its
    own JSON file.

    Args:
        float_model (Any): The original floating-point model.
        quantized_model (Any): The model after quantization.
        repr_dataset (Callable): Representative dataset used for similarity metrics computation.
        validation_dataset (Callable): Validation dataset used for similarity metrics computation.
        fw_report_utils (FrameworkReportUtils): Utilities for generating framework-specific reports.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        tuple[Dict[str, Any], Dict[str, Any]]: The collected similarity metrics, and the
        degradation-cause analysis report for the degraded layers.
    """
    # Get metadata from the quantized model
    quantized_model_metadata = fw_report_utils.get_metadata_fn(quantized_model)

    # Collect histograms on the float model.
    float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
    mi = ModelCollector(float_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
    for _data in tqdm(repr_dataset(), desc="Collecting Histograms"):
        mi.infer(_data)

    # Add the collected histograms to TensorBoard.
    fw_report_utils.tb_utils.add_histograms_to_tensorboard(graph=float_graph)

    # Compute similarity metrics on the representative dataset and the validation set.
    repr_similarity = fw_report_utils.similarity_calculator.compute_similarity_metrics(
        float_model=float_model,
        quantized_model=quantized_model,
        dataset=repr_dataset,
        custom_similarity_metrics=xquant_config.custom_similarity_metrics)
    val_similarity = fw_report_utils.similarity_calculator.compute_similarity_metrics(
        float_model=float_model,
        quantized_model=quantized_model,
        dataset=validation_dataset,
        custom_similarity_metrics=xquant_config.custom_similarity_metrics,
        is_validation=True)
    similarity_metrics = {
        OUTPUT_SIMILARITY_METRICS_REPR: repr_similarity[0],
        OUTPUT_SIMILARITY_METRICS_VAL: val_similarity[0],
        INTERMEDIATE_SIMILARITY_METRICS_REPR: repr_similarity[1],
        INTERMEDIATE_SIMILARITY_METRICS_VAL: val_similarity[1]
    }

    # Add a graph of the quantized model with the similarity metrics to TensorBoard for visualization.
    fw_report_utils.tb_utils.add_graph_to_tensorboard(quantized_model,
                                                      similarity_metrics,
                                                      repr_dataset,
                                                      quantized_model_metadata)

    # Add text information (like max cut and output similarity metrics) to the TensorBoard writer.
    fw_report_utils.tb_utils.add_text_information(similarity_metrics,
                                                  quantized_model_metadata)

    # Save the similarity data to a JSON file.
    fw_report_utils.dump_report_to_json(report_dir=xquant_config.report_dir,
                                        collected_data=similarity_metrics)

    # Detect degraded layers from the quantization similarities.
    degrade_layers = core_detect_degrade_layer(repr_similarity, val_similarity, xquant_config)

    # Ratio of degraded layers to all intermediate layers. Guard against an empty intermediate
    # similarity dict, which previously raised ZeroDivisionError; with no layers the ratio is
    # treated as 0.0 (NOTE(review): confirm this is the desired policy for such models).
    num_intermediate_layers = len(repr_similarity[1])
    if num_intermediate_layers > 0:
        degrade_layer_ratio = len(degrade_layers) / num_intermediate_layers
    else:
        degrade_layer_ratio = 0.0

    # Analyze degradation causes only when the ratio is below the configured threshold.
    if degrade_layer_ratio < xquant_config.threshold_degrade_layer_ratio:
        troubleshoot_data = core_judge_troubleshoot(float_model, quantized_model, float_graph,
                                                    degrade_layers, repr_dataset, xquant_config)
    else:
        troubleshoot_data = {}

    # Save the troubleshooting data to a JSON file.
    fw_report_utils.dump_troubleshoot_report_to_json(report_dir=xquant_config.report_dir,
                                                     collected_data=troubleshoot_data)

    return similarity_metrics, troubleshoot_data
@@ -0,0 +1,78 @@
1
+ # Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.xquant import XQuantConfig
17
+ import matplotlib.pyplot as plt
18
+
19
def make_similarity_graph(metrics_name: str,
                          dataset_name: str,
                          intermediate_similarity: dict,
                          degrade_layers: list[str],
                          xquant_config: XQuantConfig) -> None:
    """
    Draw a per-layer quantization-error graph and save it as a PNG image.

    The graph has the following elements:
      - One point per layer in ``intermediate_similarity`` (in dict order), plotting the value of
        ``metrics_name`` for that layer.
      - Red circle markers and red x-axis labels on the layers listed in ``degrade_layers``.
      - A dashed red line at ``xquant_config.threshold_quantize_error[metrics_name]``.

    The image is written to ``<report_dir>/quant_loss_<metrics_name>_<dataset_name>.png``.

    Args:
        metrics_name (str): Name of the metric used to measure the quantization error (e.g. 'mse').
        dataset_name (str): Dataset name ('repr' for the representative dataset or 'val' for the
            validation dataset).
        intermediate_similarity (dict): Per-layer results, mapping a layer name to a dict that
            contains a value for ``metrics_name``.
        degrade_layers (list[str]): Names of the detected degraded layers. Each must be a key of
            ``intermediate_similarity``.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization. Its
            ``threshold_quantize_error`` and ``report_dir`` are used.

    Returns:
        None
    """
    import os

    plot_title = "quant_loss_{}_{}".format(metrics_name, dataset_name)

    # Layer names (x) and their metric values (y), in the order of intermediate_similarity.
    data_x = list(intermediate_similarity.keys())
    data_y = [intermediate_similarity[layer_name][metrics_name] for layer_name in data_x]

    # Map each layer to its position once instead of a linear list.index() per degraded layer.
    layer_positions = {layer_name: position for position, layer_name in enumerate(data_x)}
    degrade_layer_indexes = [layer_positions[layer] for layer in degrade_layers]

    threshold = xquant_config.threshold_quantize_error[metrics_name]

    # Make the graph, widening it with the number of layers.
    fig = plt.figure(figsize=(max(int(len(data_x) / 3), 1), 5))
    try:
        # Draw the plot with hollow red circle markers on the degraded layers.
        plt.plot(data_x, data_y, markevery=degrade_layer_indexes, marker='o', markersize=20,
                 markeredgecolor='r', markerfacecolor=[0.0, 0.0, 0.0, 0.0], markeredgewidth=3)
        plt.grid()
        # Add labels.
        plt.title(plot_title)
        plt.xlabel("layer name")
        plt.ylabel(metrics_name)
        plt.xticks(rotation=90)
        # Color the x-axis labels of degraded layers red.
        _, plt_x_labels = plt.xticks()
        for degrade_layer_index in degrade_layer_indexes:
            plt_x_labels[degrade_layer_index].set_color("red")
        # Add the threshold line and its label, offset slightly above the line in proportion to
        # the data range. With no data the offset is 0 (max()/min() of an empty list would raise).
        plt.hlines(threshold, 0, len(data_x) - 1, "red", linestyles='dashed')
        value_range = (max(data_y) - min(data_y)) if data_y else 0
        plt.text(0, threshold + value_range * 0.02, "threshold={}".format(threshold), color="red")

        plt.tight_layout()

        # Save the image.
        path_outfile = os.path.join(xquant_config.report_dir, "{}.png".format(plot_title))
        plt.savefig(path_outfile)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.clf()
        plt.close(fig)

    return None
@@ -16,12 +16,12 @@
16
16
  from typing import Callable
17
17
 
18
18
  from model_compression_toolkit.verify_packages import FOUND_TORCH
19
- from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
20
19
  from model_compression_toolkit.xquant import XQuantConfig
21
20
  from model_compression_toolkit.logger import Logger
22
21
 
23
22
  if FOUND_TORCH:
24
23
  from model_compression_toolkit.xquant.pytorch.pytorch_report_utils import PytorchReportUtils
24
+ from model_compression_toolkit.xquant.pytorch.core_report_generator import core_report_generator, core_report_generator_troubleshoot
25
25
  import torch
26
26
 
27
27
  def xquant_report_pytorch_experimental(float_model: torch.nn.Module,
@@ -58,7 +58,47 @@ if FOUND_TORCH:
58
58
 
59
59
  return _collected_data
60
60
 
61
def xquant_report_troubleshoot_pytorch_experimental(float_model: torch.nn.Module,
                                                    quantized_model: torch.nn.Module,
                                                    repr_dataset: Callable,
                                                    validation_dataset: Callable,
                                                    xquant_config: XQuantConfig) -> tuple[dict, dict]:
    """
    Generate an explainable quantization report for a quantized PyTorch model, detect the degraded
    layers, and judge the causes of their degradation.

    Args:
        float_model (torch.nn.Module): The original floating-point PyTorch model.
        quantized_model (torch.nn.Module): The quantized PyTorch model.
        repr_dataset (Callable): The representative dataset used during quantization.
        validation_dataset (Callable): The validation dataset used for evaluation.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        tuple[dict, dict]: A dictionary containing the collected similarity metrics and report data,
        and a dictionary containing the degradation-cause analysis report for the degraded layers.
    """
    # Initialize the logger with the report directory.
    Logger.set_log_file(log_folder=xquant_config.report_dir)

    try:
        pytorch_report_utils = PytorchReportUtils(xquant_config.report_dir)

        _collected_data, _troubleshoot_data = core_report_generator_troubleshoot(
            float_model=float_model,
            quantized_model=quantized_model,
            repr_dataset=repr_dataset,
            validation_dataset=validation_dataset,
            fw_report_utils=pytorch_report_utils,
            xquant_config=xquant_config)
    finally:
        # Shut the logger down even if report generation raises, so the log file is not left open.
        Logger.shutdown()

    return _collected_data, _troubleshoot_data
95
+
96
+
61
97
  else:
62
98
def xquant_report_pytorch_experimental(*args, **kwargs):
    """Fallback defined when PyTorch is not available; any call logs a critical 'torch missing' message via Logger."""
    Logger.critical("PyTorch must be installed to use 'xquant_report_pytorch_experimental'. The 'torch' package is missing.")  # pragma: no cover
101
+
102
def xquant_report_troubleshoot_pytorch_experimental(*args, **kwargs):
    """Fallback defined when PyTorch is not available; any call logs a critical 'torch missing' message via Logger."""
    Logger.critical("PyTorch must be installed to use 'xquant_report_troubleshoot_pytorch_experimental'. The 'torch' package is missing.")  # pragma: no cover
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import json
17
+ import os
18
+
19
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
21
+ from typing import Any, Dict, Callable
22
+
23
+ from model_compression_toolkit.xquant.common.constants import REPORT_FILENAME, TROUBLESHOOT_REPORT_FILENAME
24
+ from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
25
+ from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
26
+ from model_compression_toolkit.xquant.pytorch.similarity_calculator import SimilarityCalculator
27
+ from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
28
+ from model_compression_toolkit.logger import Logger
29
+
30
+
31
class FrameworkReportUtils:
    """
    Class with various utility components required for generating the report in a specific framework.
    """

    def __init__(self,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 similarity_calculator: SimilarityCalculator,
                 dataset_utils: DatasetUtils,
                 model_folding_utils: ModelFoldingUtils,
                 tb_utils: TensorboardUtils,
                 get_metadata_fn: Callable):
        """
        Initializes the FrameworkReportUtils class with various utility components required for generating the report.

        Args:
            fw_info (FrameworkInfo): Information about the framework being used.
            fw_impl (FrameworkImplementation): The implemented functions of the framework.
            similarity_calculator (SimilarityCalculator): A utility for calculating similarity metrics.
            dataset_utils (DatasetUtils): Utilities for handling datasets.
            model_folding_utils (ModelFoldingUtils): Utilities for model folding operations.
            tb_utils (TensorboardUtils): Utilities for TensorBoard operations.
            get_metadata_fn (Callable): Function to retrieve the metadata from the quantized model.
        """
        self.fw_info = fw_info
        self.fw_impl = fw_impl
        self.similarity_calculator = similarity_calculator
        self.dataset_utils = dataset_utils
        self.model_folding_utils = model_folding_utils
        self.tb_utils = tb_utils
        self.get_metadata_fn = get_metadata_fn

    def dump_report_to_json(self,
                            report_dir: str,
                            collected_data: Dict[str, Any]):
        """
        Dump the collected report data (similarity metrics, etc.) into the REPORT_FILENAME JSON file.

        Args:
            report_dir (str): Directory where the report will be saved (created if missing).
            collected_data (Dict[str, Any]): Data collected during report generation.
        """
        self._dump_json(report_dir, REPORT_FILENAME, collected_data)

    def dump_troubleshoot_report_to_json(self,
                                         report_dir: str,
                                         collected_data: Dict[str, Any]):
        """
        Dump the troubleshooting report data (degradation-cause analysis) into the
        TROUBLESHOOT_REPORT_FILENAME JSON file.

        Args:
            report_dir (str): Directory where the report will be saved (created if missing).
            collected_data (Dict[str, Any]): Troubleshooting data collected during report generation.
        """
        self._dump_json(report_dir, TROUBLESHOOT_REPORT_FILENAME, collected_data)

    @staticmethod
    def _dump_json(report_dir: str,
                   file_name: str,
                   collected_data: Dict[str, Any]):
        """
        Write ``collected_data`` as indented JSON to ``<report_dir>/<file_name>``.

        Args:
            report_dir (str): Target directory; it is created if it does not exist yet.
            file_name (str): Name of the JSON file inside ``report_dir``.
            collected_data (Dict[str, Any]): JSON-serializable data to write.
        """
        report_file_name = os.path.abspath(os.path.join(report_dir, file_name))
        Logger.info(f"Dumping report data to: {report_file_name}")

        # Create the directory so that dumping does not fail with FileNotFoundError.
        os.makedirs(os.path.dirname(report_file_name), exist_ok=True)
        with open(report_file_name, 'w') as f:
            json.dump(collected_data, f, indent=4)