mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -0,0 +1,562 @@
1
+ # Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
from functools import partial
from typing import Tuple, List, Callable, Optional
import os
import glob

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn import Hardswish, SiLU, PReLU, ELU, GELU
from torch.nn.functional import hardswish, silu, prelu, elu, gelu
from tensorboard.backend.event_processing.plugin_event_accumulator import EventAccumulator

from model_compression_toolkit.xquant import XQuantConfig
from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatasetUtils
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.logger import Logger
30
+
31
+
32
+ def _compute_zscore(statistics_collector: BaseStatsCollector) -> Tuple[np.array, np.array, np.array]:
33
+ """
34
+ Compute z-score from collected histogram.
35
+
36
+ Args:
37
+ statistics_collector (BaseStatsCollector): Statistics collector to compute z-score.
38
+
39
+ Returns:
40
+ Tuple[np.array, np.array, np.array]: A tuple containing computed z-score, histogram statistics (bins and counts).
41
+ """
42
+
43
+ if statistics_collector.require_collection():
44
+ if hasattr(statistics_collector, 'hc'):
45
+ if statistics_collector.hc.is_legal:
46
+ bins, counts = statistics_collector.hc.get_histogram()
47
+ if bins is not None and counts is not None:
48
+ bins = np.copy(bins)
49
+ counts = np.copy(counts)
50
+ bins = bins[:-1] # take out the last range
51
+
52
+ # Compute the z-score
53
+ mu = np.sum(bins * counts) / np.sum(counts)
54
+ sigma = np.sqrt(np.sum(np.power(bins - mu, 2.0) * counts) / np.sum(counts))
55
+ z_score = np.abs(bins - mu) / sigma
56
+
57
+ return (z_score, bins, counts, mu, sigma)
58
+
59
+ return None
60
+
61
+
62
def _save_outlier_histogram(layer_name: str, zscore_hist: Tuple[np.ndarray, np.ndarray, np.ndarray, float, float],
                            z_threshold: float, img_filename: str):
    """
    Save the output activation histogram of a layer as an image, highlighting outlier bins.

    Bars are drawn at the bin positions (red for bins judged as outliers, blue otherwise).
    A second x axis on top shows the z-score of each tick. Dashed black lines mark the
    smallest and largest bin; dashed red lines mark the z-score thresholds when they fall
    inside the plotted range, otherwise only an arrow label is drawn at the plot edge.

    Args:
        layer_name (str): The name of the layer (used as the plot title).
        zscore_hist (Tuple[np.ndarray, np.ndarray, np.ndarray, float, float]): A tuple containing
            z-score, histogram statistics (bins and counts), mean and standard deviation.
        z_threshold (float): Threshold for detecting outliers.
        img_filename (str): Filename to save histogram image to.
    """
    _, bins, counts, mu, sigma = zscore_hist

    # The bin values are used as bar positions as they are (no re-centering is applied).
    bin_centers = bins

    # Positions of the z-score thresholds on the bins axis.
    # NOTE: 10e-16 equals 1e-15; it only keeps the two thresholds apart when sigma is 0.
    lower_th_z_score_vline = mu - z_threshold * (sigma + 10e-16)
    upper_th_z_score_vline = mu + z_threshold * (sigma + 10e-16)

    def detect_outliers_after_peak(counts, bin_centers, lower_thresh, upper_thresh):
        """
        Color the histogram bins that lie outside the thresholds and are not monotonically decreasing.

        Args:
            counts (np.ndarray): Histogram counts.
            bin_centers (np.ndarray): Position of each histogram bin.
            lower_thresh (float): Threshold for detecting outliers (lower side).
            upper_thresh (float): Threshold for detecting outliers (upper side).

        Returns:
            list: One color per bin, 'red' for outlier bins and 'blue' for the others.
        """
        num_bins = len(counts)
        colors = ['blue'] * num_bins
        upper_outlier_indices = []
        lower_outlier_indices = []

        # Positive side: the first bin above the upper threshold whose count rises compared
        # to its left neighbor starts the outlier region, which extends to the last bin.
        for i in range(num_bins):
            if bin_centers[i] > upper_thresh and i > 0 and counts[i] > counts[i - 1]:
                upper_outlier_indices = list(range(i, num_bins))
                break

        # Negative side: scanning from the right, the first bin below the lower threshold
        # whose count rises compared to its right neighbor ends the outlier region,
        # which extends back to the first bin.
        for i in reversed(range(num_bins)):
            if bin_centers[i] < lower_thresh and i < num_bins - 1 and counts[i] > counts[i + 1]:
                lower_outlier_indices = list(range(0, i + 1))
                break

        for idx in upper_outlier_indices:
            colors[idx] = 'red'

        for idx in lower_outlier_indices:
            colors[idx] = 'red'

        return colors

    # Detect outliers
    colors = detect_outliers_after_peak(counts, bin_centers, lower_th_z_score_vline, upper_th_z_score_vline)

    # A bin spacing needs at least two bins; with a single bin fall back to
    # matplotlib's default bar width instead of failing on bins[1].
    bar_width = 0.9 * (bins[1] - bins[0]) if len(bins) > 1 else 0.8

    # Plot
    fig = plt.figure(figsize=(10, 5))
    try:
        ax = fig.add_subplot(111)
        ax_top = ax.twiny()

        ax.bar(bin_centers, counts, width=bar_width, color=colors)
        ax.set_ylabel('counts')
        ax.set_xlabel('bins')

        # Reuse the bottom-axis ticks on the top axis, labelled with their z-score.
        bin_ticks = ax.get_xticks()
        z_score_ticks = np.abs(bin_ticks - mu) / (sigma + 1e-16)

        ax.set_xticks(bin_ticks)
        ax_top.set_xticks(bin_ticks)
        ax_top.set_xticklabels([f'{zs:.2f}' for zs in z_score_ticks])
        ax_top.set_xlabel('z-score')
        plt.ylim(0, np.max(counts) * 1.1)
        ax.set_title(layer_name)

        xlim_min, xlim_max = ax.get_xlim()
        xlim_range = xlim_max - xlim_min

        # Mark the largest bin and annotate it with its z-score.
        max_bin = max(bins)
        upper_zscore = abs(max_bin - mu) / (sigma + 1e-16)
        ax.axvline(x=max_bin, color='black', linestyle='--')
        plt.text(max_bin - 0.18*xlim_range, plt.ylim()[1] * 0.9, f'upper zscore={upper_zscore:.2f}', color='black')

        # Mark the smallest bin and annotate it with its z-score.
        min_bin = min(bins)
        lower_zscore = abs(min_bin - mu) / (sigma + 1e-16)
        ax.axvline(x=min_bin, color='black', linestyle='--')
        plt.text(min_bin + 0.01*xlim_range, plt.ylim()[1] * 0.9, f'lower zscore={lower_zscore:.2f}', color='black')

        # Upper threshold: draw the line if it is inside the plot, otherwise point to the right.
        if upper_th_z_score_vline <= xlim_max:
            ax_top.axvline(x=upper_th_z_score_vline, color='r', linestyle='--')
            plt.text(upper_th_z_score_vline - 0.12*xlim_range, plt.ylim()[1] * 0.95, f'threshold={z_threshold}', color='red')
        else:
            plt.text(xlim_max - 0.18*xlim_range, plt.ylim()[1] * 0.95, f'threshold={z_threshold} ->', color='red')

        # Lower threshold: draw the line if it is inside the plot, otherwise point to the left.
        if xlim_min <= lower_th_z_score_vline:
            ax_top.axvline(x=lower_th_z_score_vline, color='r', linestyle='--')
            plt.text(lower_th_z_score_vline + 0.01*xlim_range, plt.ylim()[1] * 0.95, f'threshold={z_threshold}', color='red')
        else:
            plt.text(xlim_min + 0.01*xlim_range, plt.ylim()[1] * 0.95, f'<- threshold={z_threshold}', color='red')

        plt.tight_layout()

        plt.savefig(img_filename)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
187
+
188
+
189
def judge_outlier_removal(degrade_layers: list[str], float_graph: Graph,
                          xquant_config: XQuantConfig) -> List[Tuple[str, str]]:
    """
    Judge whether the degrade layers have outliers from statistics information.

    For every degrade layer that has an output statistics collector, the z-score of its
    histogram is computed. If any bin reaches the configured threshold, the layer is reported,
    a warning is logged, and the histogram image is saved when the output directory
    ``<report_dir>/outlier_histgrams`` exists.

    Args:
        degrade_layers (list[str]): A list of detected degrade layers.
        float_graph (Graph): Graph to get statistics for the layer.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        List[Tuple[str, str]]: One ``(layer name, histogram image filename)`` tuple per layer
        with outliers. The list is empty when no layer matches the condition. The filename is
        reported even if the image could not be written because the output directory is missing.
    """
    z_threshold = xquant_config.threshold_zscore_outlier_removal

    outlier_layers = []

    for layer_name in degrade_layers:
        nodes = float_graph.find_node_by_name(layer_name)
        assert len(nodes) == 1, f"Expected exactly one node named '{layer_name}', found {len(nodes)}."

        collector = float_graph.get_out_stats_collector(nodes[0])
        if collector is None:
            continue

        zscore_hist = _compute_zscore(collector)
        if zscore_hist is None:
            continue

        zscore = zscore_hist[0]
        if not np.any(zscore >= z_threshold):
            continue

        # The directory is resolved only once an outlier is found, so report_dir is not
        # touched for runs without outliers.
        histogram_dir = os.path.join(xquant_config.report_dir, 'outlier_histgrams')
        img_filename = os.path.join(histogram_dir, f'{layer_name}.png')
        if os.path.isdir(histogram_dir):
            _save_outlier_histogram(layer_name, zscore_hist, z_threshold, img_filename)
        else:
            Logger.warning("Output directory of outlier histgram images({}/outlier_histgrams) not found. Skipping output outlier histgram images.".format(xquant_config.report_dir))

        # Print to Console: the explanatory header is logged once, before the first layer.
        if not outlier_layers:
            Logger.warning("There are output values that deviate significantly from the average. Refer to the following images and the TroubleShooting Documentation (MCT XQuant Extension Tool) of \'Outlier Removal\'.")
        Logger.warning(img_filename)

        outlier_layers.append((layer_name, img_filename))

    return outlier_layers
234
+
235
+
236
def judge_shift_negative_activation(float_graph: Graph,
                                    xquant_config: XQuantConfig) -> list[str]:
    """
    Find the layers whose operation is PReLU / ELU / Hardswish / SiLU / GELU.

    Both the ``torch.nn`` module classes and their ``torch.nn.functional`` counterparts
    are matched against each node's ``layer_class``. A warning header is logged before the
    first match, followed by one ``name=class`` warning per matching layer.

    Args:
        float_graph (Graph): Graph to get class name for the layer.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.
            Not read by this function.

    Returns:
        list[str]: A list of shift negative activation layers (node names, in graph order).

    """
    target_operations = (
        PReLU, prelu,
        ELU, elu,
        Hardswish, hardswish,
        SiLU, silu,
        GELU, gelu,
    )

    matched_names = []

    for node in float_graph.nodes:
        if node.layer_class not in target_operations:
            continue

        # Print to Console (header only once, before the first reported layer)
        if not matched_names:
            Logger.warning("There are activations that contain negative values. Refer to the troubleshooting manual of \"Shift Negative Activation\".")
        Logger.warning("{}={}".format(node.name, node.layer_class.__name__))

        matched_names.append(node.name)

    return matched_names
267
+
268
+ def _compute_activations(name: str, activations: dict):
269
+ """
270
+ Creates a hook function to capture the activations of a layer.
271
+
272
+ Args:
273
+ name (str): The name of the layer.
274
+ activations (dict): The dictionary to store the activations.
275
+
276
+ Returns:
277
+ hook (function): The hook function to register with the layer.
278
+ """
279
+ def hook(model, input, output):
280
+ activation = input[0].detach()
281
+
282
+ if name not in activations.keys():
283
+ activations[name] = []
284
+ activations[name].append(activation)
285
+
286
+ return hook
287
+
288
+ def judge_unbalanced_concatnation(degrade_layers: list[str],
289
+ float_model: torch.nn.Module,
290
+ dataset: Callable,
291
+ xquant_config: XQuantConfig) -> List[List[Tuple[str, str, str]]]:
292
+ """
293
+ Judge whether the layer combines layers with significantly different value ranges
294
+
295
+ Args:
296
+ degrade_layers (list[str]): A list of detected degrade layers.
297
+ float_model (torch.nn.Module): The original floating-point Pytorch model.
298
+ dataset (Callable): Representative dataset used for similarity metrics computation.
299
+ xquant_config (XQuantConfig): Configuration settings for explainable quantization.
300
+
301
+ Returns:
302
+ List[List[Tuple[str, str, str]]]: A list containing layer name before concatnation, and scale adjustment.
303
+ If the layer does not match the condition, it returns None.
304
+ """
305
+
306
+ judge_results = []
307
+
308
+ if(xquant_config.quantize_reported_dir is None):
309
+ Logger.warning("XQuantConfig.quantize_reported_dir is not defined. Skip judging of \'Unbalanced \"concatenation\"\'.")
310
+ return judge_results
311
+
312
+ org_torch_add = torch.add
313
+ org_torch_tensor_add = torch.Tensor.__add__
314
+
315
+ concat_layers = {}
316
+ concat_layers_add = {}
317
+ activations_float = {}
318
+ float_model_modules = dict([*float_model.named_modules()])
319
+ for layer_name in degrade_layers:
320
+ is_search = layer_name[-3:] == '_bn' or layer_name[-10:] == '_collapsed' or layer_name[:3] == 'bn_'
321
+ if is_search:
322
+ logdir = xquant_config.quantize_reported_dir
323
+ tblog_names = ['initial_graph', 'after_graph_preparation', 'pre_statistics_collection_substitutions']
324
+ tblog_to_nodename = {}
325
+
326
+ for tblog_name in tblog_names:
327
+ tfevent_paths = sorted(glob.glob(os.path.join(logdir, 'tensorboard_logs', tblog_name, 'events.out.tfevents.*')))
328
+ if(len(tfevent_paths) == 0):
329
+ tfevent_paths = sorted(glob.glob(os.path.join(logdir, '**', 'tensorboard_logs', tblog_name, 'events.out.tfevents.*')))
330
+
331
+ if(len(tfevent_paths) == 0):
332
+ Logger.warning("TensorBoard logs not found in XQuantConfig.quantize_reported_dir. Skip judging of \'Unbalanced \"concatenation\"\'.")
333
+ return judge_results
334
+ tfevent = EventAccumulator(path=tfevent_paths[0])
335
+ tfevent.Reload()
336
+
337
+ node_names = []
338
+ for log in str(tfevent.Graph()).splitlines():
339
+ if(log[:7] == ' name:'):
340
+ node_name = log.split('\"')[-2]
341
+ node_names.append(node_name)
342
+
343
+ tblog_to_nodename[tblog_name] = node_names
344
+
345
+ # layer_name = 'features_1_conv_0_0_bn'
346
+
347
+ after_graph_preparation_log = tblog_to_nodename['after_graph_preparation']
348
+
349
+ first_node, second_node = None, None
350
+
351
+ if layer_name[-3:] == '_bn': # features_1_conv_0_0_bn
352
+
353
+ target_name = layer_name[:-3] # features_1_conv_0_0
354
+ for idx, node_name in enumerate(after_graph_preparation_log):
355
+ _name = node_name.split('/')[-1] # node_name: MobileNetV2/Conv2d_1/features_1_conv_0_0, _name: features_1_conv_0_0
356
+
357
+ if _name == target_name:
358
+ _idx, _first_node = idx, node_name
359
+ _second_node = after_graph_preparation_log[_idx + 1]
360
+ continue
361
+
362
+ elif layer_name[-10:] == '_collapsed':
363
+
364
+ target_name = layer_name[:-10]
365
+
366
+ sorted_node_names_by_len = sorted(after_graph_preparation_log, key=len, reverse=True)
367
+ _first_node = None
368
+ for node_name in sorted_node_names_by_len:
369
+ _name = node_name.split('/')[-1]
370
+ if target_name.startswith(_name):
371
+ _first_node = _name
372
+ break
373
+
374
+ if(_first_node is not None):
375
+ for node_name in sorted_node_names_by_len:
376
+ _name = node_name.split('/')[-1]
377
+ target_name_exclude_first_node = target_name[len(_first_node)+1:]
378
+ if _name == target_name_exclude_first_node or _name+"_bn" == target_name_exclude_first_node:
379
+ _second_node = _name
380
+ break
381
+
382
+ elif layer_name[:3] == 'bn_':
383
+
384
+ target_name = layer_name[3:]
385
+ for idx, node_name in enumerate(after_graph_preparation_log):
386
+ _name = node_name.split('/')[-1]
387
+
388
+ if _name == target_name:
389
+ _idx, _second_node = idx, node_name
390
+ _first_node = after_graph_preparation_log[_idx - 1]
391
+ continue
392
+
393
+ if _first_node is not None and _second_node is not None:
394
+ first_node = _first_node.split('/')[-1].replace("_", ".") # features.1.conv.0.0
395
+ second_node = _second_node.split('/')[-1].replace("_", ".") # features.1.conv.0.1
396
+ if first_node in float_model_modules.keys() and second_node in float_model_modules.keys():
397
+ float_model_modules[first_node].register_forward_hook(_compute_activations(first_node, activations_float))
398
+ float_model_modules[second_node].register_forward_hook(_compute_activations(second_node, activations_float))
399
+
400
+ concat_layers[layer_name] = (first_node, second_node)
401
+ elif first_node in float_model_modules.keys() and second_node == "add":
402
+ float_model_modules[first_node].register_forward_hook(_compute_activations(first_node, activations_float))
403
+ concat_layers_add[layer_name] = first_node
404
+
405
+ # Hooks cannot be registered on add operations, so temporarily replace "torch.add" and "+" with wrapper functions that capture the values produced after the first_node.
406
+ add_activations = {}
407
+ for first_node in concat_layers_add.values():
408
+ add_activations[first_node] = []
409
+
410
def hook_add(x, y, *args, **kwargs):
    """
    Wrapper installed in place of torch.add to detect additions during model execution.

    The addition is always delegated to the original torch.add. Each operand is then
    compared with the most recent activation recorded for every tracked first node;
    on the first match the detached sum is stored under that node's name.

    Args:
        x (torch.Tensor): The first operand in the addition operation.
        y (torch.Tensor): The second operand in the addition operation. Non-tensor
            operands (e.g. Python scalars) are accepted and simply never match.
        *args: Additional positional arguments passed to torch.add.
        **kwargs: Additional keyword arguments passed to torch.add.

    Returns:
        torch.Tensor: The result of the addition operation.
    """
    add_result = org_torch_add(x, y, *args, **kwargs)
    for first_node in add_activations:
        recorded = activations_float.get(first_node)
        if recorded is None or len(recorded) == 0:
            # The tracked node has not produced an output yet, so this add cannot follow it.
            continue
        conv_output = recorded[-1]
        if conv_output is None:
            continue
        # torch.equal only accepts tensors; a scalar operand can never be the node output.
        if any(torch.is_tensor(operand) and torch.equal(operand, conv_output) for operand in (x, y)):
            add_activations[first_node].append(add_result.detach())
            break
    return add_result
430
+
431
def hook_tensor_add(self, other):
    """
    Wrapper installed in place of torch.Tensor.__add__ to detect [tensor + value] during model execution.

    The addition is always delegated to the original torch.Tensor.__add__. Each operand is
    then compared with the most recent activation recorded for every tracked first node;
    on the first match the detached sum is stored under that node's name.

    Args:
        self (torch.Tensor): The left operand of the addition.
        other (torch.Tensor): The right operand of the addition. Non-tensor operands
            (e.g. Python scalars in ``x + 1``) are accepted and simply never match.

    Returns:
        torch.Tensor: The result of the addition operation.
    """
    add_result = org_torch_tensor_add(self, other)
    if not torch.is_tensor(add_result):
        # e.g. NotImplemented for an unsupported right operand: nothing to record.
        return add_result
    for first_node in add_activations:
        recorded = activations_float.get(first_node)
        if recorded is None or len(recorded) == 0:
            # The tracked node has not produced an output yet, so this add cannot follow it.
            continue
        conv_output = recorded[-1]
        if conv_output is None:
            continue
        # torch.equal only accepts tensors; a scalar operand can never be the node output.
        if any(torch.is_tensor(operand) and torch.equal(operand, conv_output) for operand in (self, other)):
            add_activations[first_node].append(add_result.detach())
            break
    return add_result
449
+
450
+ # Replace temporarily wrapper add functions
451
+ torch.add = hook_add
452
+ torch.Tensor.__add__ = hook_tensor_add
453
+
454
+ # Perform a forward pass with the input data and capture activations
455
+ dataset = partial(PytorchDatasetUtils.prepare_dataset, dataset=dataset, is_validation=True)
456
+ if(( len(concat_layers) + len(concat_layers_add)) > 0):
457
+ for data in dataset():
458
+ with torch.no_grad():
459
+ _ = float_model(*data)
460
+
461
+ # Restore the original add functions
462
+ torch.add = org_torch_add
463
+ torch.Tensor.__add__ = org_torch_tensor_add
464
+
465
+ for layer_name in concat_layers.keys():
466
+ first_node, second_node = concat_layers[layer_name]
467
+
468
+ act_first_node = activations_float.get(first_node)
469
+ act_second_node = activations_float.get(second_node)
470
+
471
+ all_act_first_node, all_act_second_node = torch.cat(act_first_node), torch.cat(act_second_node)
472
+ min_act_first_node, min_act_second_node = torch.min(all_act_first_node).item(), torch.min(all_act_second_node).item()
473
+ max_act_first_node, max_act_second_node = torch.max(all_act_first_node).item(), torch.max(all_act_second_node).item()
474
+
475
+ # Calculate act range
476
+ range_first_node = max_act_first_node - min_act_first_node
477
+ range_second_node = max_act_second_node - min_act_second_node
478
+
479
+ # Calculate ratio
480
+ range_ratio = range_second_node / (range_first_node + 1e-10)
481
+ scaling_formula = "first layer * {}".format(range_ratio)
482
+
483
+ range_ratio_over1 = range_ratio if range_ratio >= 1.0 else 1/range_ratio
484
+ th_ratio = xquant_config.threshold_ratio_unbalanced_concatenation
485
+ if range_ratio_over1 >= th_ratio:
486
+ # Print to Console
487
+ if(len(judge_results) == 0):
488
+ Logger.warning("There are unbalanced range layers concatnated. Refer to the troubleshooting manual of \'Unbalanced \"concatenation\"\'.")
489
+ Logger.warning("first layer:{}, second layer:{}, if you add a scaling operation, recommended scaling:{}".format(first_node, second_node, scaling_formula))
490
+
491
+ judge_results.append((first_node, second_node, scaling_formula))
492
+
493
+ for layer_name in concat_layers_add.keys():
494
+ first_node = concat_layers_add[layer_name]
495
+
496
+ act_first_node = activations_float.get(first_node)
497
+ act_second_node = add_activations[first_node]
498
+
499
+ all_act_first_node, all_act_second_node = torch.cat(act_first_node), torch.cat(act_second_node)
500
+ min_act_first_node, min_act_second_node = torch.min(all_act_first_node).item(), torch.min(all_act_second_node).item()
501
+ max_act_first_node, max_act_second_node = torch.max(all_act_first_node).item(), torch.max(all_act_second_node).item()
502
+
503
+ # Calculate act range
504
+ range_first_node = max_act_first_node - min_act_first_node
505
+ range_second_node = max_act_second_node - min_act_second_node
506
+
507
+ # Calculate ratio
508
+ range_ratio = range_second_node / (range_first_node + 1e-10)
509
+ scaling_formula = "first layer * {}".format(range_ratio)
510
+
511
+ range_ratio_over1 = range_ratio if range_ratio >= 1.0 else 1/range_ratio
512
+ th_ratio = xquant_config.threshold_ratio_unbalanced_concatenation
513
+ if range_ratio_over1 >= th_ratio:
514
+ # Print to Console
515
+ if(len(judge_results) == 0):
516
+ Logger.warning("There are unbalanced range layers concatnated. Refer to the troubleshooting manual of \'Unbalanced \"concatenation\"\'.")
517
+ Logger.warning("first layer:{}, second layer:{}, if you add a scaling operation, recommended scaling:{}".format(first_node, second_node, scaling_formula))
518
+
519
+ judge_results.append((first_node, second_node, scaling_formula))
520
+
521
+ return judge_results
522
+
523
def judge_mixed_precision_with_model_output_loss_objective(quantized_model: torch.nn.Module,
                                                           xquant_config: XQuantConfig) -> list:
    """
    Judge whether the bitwidth of the final layer is at or below the configured threshold.

    The last two direct children of the model are inspected: the second-to-last child for
    the 'weight' entry of its `weights_quantizers`, and the last child for its
    `activation_holder_quantizer`. A child without the respective attribute is ignored.

    Args:
        quantized_model (torch.nn.Module): The quantized Pytorch model.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        list: A one-element list holding the name of the second-to-last child when its weight
        bitwidth or the last child's activation bitwidth is less than or equal to the
        threshold; an empty list otherwise, including when the model has fewer than two
        direct children.
    """
    threshold_bitwidth = xquant_config.threshold_bitwidth_mixed_precision_with_model_output_loss_objective

    # Materialize the children once; a model with fewer than two children has no
    # (layer, activation holder) pair to inspect.
    children = list(quantized_model.named_children())
    if len(children) < 2:
        return []
    (last_layer_name, last_layer), (_, last_layer_activation) = children[-2:]

    bitwidth_weights = None
    if hasattr(last_layer, "weights_quantizers"):
        bitwidth_weights = last_layer.weights_quantizers['weight'].num_bits

    bitwidth_activation = None
    if hasattr(last_layer_activation, "activation_holder_quantizer"):
        bitwidth_activation = last_layer_activation.activation_holder_quantizer.num_bits

    is_below_threshold = any(bitwidth is not None and bitwidth <= threshold_bitwidth
                             for bitwidth in (bitwidth_weights, bitwidth_activation))
    if not is_below_threshold:
        return []

    # Print to Console
    Logger.warning("the quantization bitwidth of the last layer is an extremely small number. Refer to the troubleshooting manual of \'Mixed Precision with model output loss objective\'.")
    Logger.warning("bidwidth of {}:{}(Weight), {}(Activation)".format(last_layer_name, bitwidth_weights, bitwidth_activation))

    return [last_layer_name]
@@ -19,18 +19,17 @@ from model_compression_toolkit.target_platform_capabilities.constants import DEF
19
19
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
20
20
  AttachTpcToPytorch
21
21
 
22
- from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
23
- from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo
22
+ from model_compression_toolkit.xquant.pytorch.framework_report_utils import FrameworkReportUtils
23
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
24
24
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
25
25
  from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
26
- from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
26
+ from model_compression_toolkit.xquant.pytorch.similarity_calculator import SimilarityCalculator
27
27
  from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatasetUtils
28
28
  from model_compression_toolkit.xquant.pytorch.model_analyzer import PytorchModelAnalyzer
29
29
  from model_compression_toolkit.xquant.pytorch.similarity_functions import PytorchSimilarityFunctions
30
30
  from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
31
31
  from mct_quantizers.pytorch.metadata import get_metadata
32
32
 
33
-
34
33
  class PytorchReportUtils(FrameworkReportUtils):
35
34
  """
36
35
  Class with various utility components required for generating the report for a Pytorch model.
@@ -40,6 +39,7 @@ class PytorchReportUtils(FrameworkReportUtils):
40
39
  Args:
41
40
  report_dir: Logging dir path.
42
41
  """
42
+ fw_info = DEFAULT_PYTORCH_INFO
43
43
  fw_impl = PytorchImplementation()
44
44
  # Set the default Target Platform Capabilities (TPC) for PyTorch.
45
45
  default_tpc = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -47,7 +47,8 @@ class PytorchReportUtils(FrameworkReportUtils):
47
47
  framework_quantization_capabilities = attach2pytorch.attach(default_tpc)
48
48
 
49
49
  dataset_utils = PytorchDatasetUtils()
50
- model_folding = ModelFoldingUtils(fw_impl=fw_impl,
50
+ model_folding = ModelFoldingUtils(fw_info=fw_info,
51
+ fw_impl=fw_impl,
51
52
  fw_default_fqc=framework_quantization_capabilities)
52
53
 
53
54
  similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
@@ -57,9 +58,11 @@ class PytorchReportUtils(FrameworkReportUtils):
57
58
  device=get_working_device())
58
59
 
59
60
  tb_utils = PytorchTensorboardUtils(report_dir=report_dir,
60
- fw_impl=fw_impl)
61
+ fw_impl=fw_impl,
62
+ fw_info=fw_info)
61
63
 
62
- super().__init__(fw_impl=fw_impl,
64
+ super().__init__(fw_info=fw_info,
65
+ fw_impl=fw_impl,
63
66
  tb_utils=tb_utils,
64
67
  dataset_utils=dataset_utils,
65
68
  similarity_calculator=similarity_calculator,
@@ -15,6 +15,7 @@
15
15
  from functools import partial
16
16
 
17
17
  from typing import Tuple, Any, Dict, Callable
18
+ import torch
18
19
 
19
20
  from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
20
21
  from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
@@ -57,6 +58,7 @@ class SimilarityCalculator:
57
58
  similarity_metrics: Dict[str, Callable]) -> Dict[str, float]:
58
59
  """
59
60
  Compute the similarity between two tensors using provided similarity metrics.
61
+ If either input is not a tensor, every similarity metric is reported as 0.0.
60
62
 
61
63
  Args:
62
64
  tensors_to_compare (Tuple[Any, Any]): Tensors to compare by computing their similarity.
@@ -66,7 +68,10 @@ class SimilarityCalculator:
66
68
  Dict[str, float]: A dictionary of similarity metric names and their computed values.
67
69
  """
68
70
  x, y = tensors_to_compare
69
- similarity_metrics = {k: v(x, y) for k, v in similarity_metrics.items()}
71
+ if(torch.is_tensor(x) and torch.is_tensor(y)):
72
+ similarity_metrics = {k: v(x, y) for k, v in similarity_metrics.items()}
73
+ else:
74
+ similarity_metrics = {k: 0.0 for k, v in similarity_metrics.items()}
70
75
  return similarity_metrics
71
76
 
72
77
  def _get_float_to_quantized_compare_points(self,
@@ -41,15 +41,18 @@ class PytorchTensorboardUtils(TensorboardUtils):
41
41
 
42
42
def __init__(self,
             report_dir: str,
             fw_info: FrameworkInfo,
             fw_impl: FrameworkImplementation):
    """
    Set up the PytorchTensorboardUtils instance by forwarding every argument to the parent initializer.

    Args:
        report_dir: Directory where the reports are stored.
        fw_info: Information about the framework being used.
        fw_impl: Implementation methods for the framework.
    """
    # Arguments are passed positionally, in the same order the parent receives them.
    super().__init__(report_dir, fw_info, fw_impl)
54
57
 
55
58
  def get_graph_for_tensorboard_display(self,