mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -0,0 +1,562 @@
|
|
1
|
+
# Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from functools import partial
|
16
|
+
from typing import Tuple, List, Callable
|
17
|
+
import os
|
18
|
+
import glob
|
19
|
+
import numpy as np
|
20
|
+
import matplotlib.pyplot as plt
|
21
|
+
from model_compression_toolkit.xquant import XQuantConfig
|
22
|
+
from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatasetUtils
|
23
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
24
|
+
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
25
|
+
import torch
|
26
|
+
from torch.nn import Hardswish, SiLU, PReLU, ELU, GELU
|
27
|
+
from torch.nn.functional import hardswish, silu, prelu, elu, gelu
|
28
|
+
from tensorboard.backend.event_processing.plugin_event_accumulator import EventAccumulator
|
29
|
+
from model_compression_toolkit.logger import Logger
|
30
|
+
|
31
|
+
|
32
|
+
def _compute_zscore(statistics_collector: "BaseStatsCollector") -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, float]:
    """
    Compute per-bin z-scores from the histogram held by a statistics collector.

    Args:
        statistics_collector (BaseStatsCollector): Statistics collector to compute z-score.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, float, float]: A tuple
        ``(z_score, bins, counts, mu, sigma)`` where ``bins`` is the histogram bins with the
        last entry dropped, ``counts`` is the histogram counts, ``mu`` / ``sigma`` are the
        count-weighted mean and standard deviation of ``bins`` and ``z_score`` is
        ``|bins - mu| / sigma`` for every bin.
        Returns None when the collector does not require collection, has no legal ``hc``
        histogram collector, has no histogram data, or the histogram holds no samples.
    """
    if not statistics_collector.require_collection():
        return None
    if not hasattr(statistics_collector, 'hc') or not statistics_collector.hc.is_legal:
        return None

    bins, counts = statistics_collector.hc.get_histogram()
    if bins is None or counts is None:
        return None

    # Work on copies so the data returned by the collector is never modified.
    counts = np.copy(counts)
    # Drop the last bin entry so that `bins` lines up element-wise with `counts`
    # (presumably `bins` holds the bin edges, i.e. one more entry than `counts` - TODO confirm).
    bins = np.copy(bins)[:-1]

    total = np.sum(counts)
    if total <= 0:
        # An empty histogram has no defined mean; report "no statistics" instead of NaNs.
        return None

    # Count-weighted mean and standard deviation of the bin positions.
    mu = np.sum(bins * counts) / total
    sigma = np.sqrt(np.sum(np.power(bins - mu, 2.0) * counts) / total)

    # The small epsilon keeps the z-score finite when all samples fall into one bin (sigma == 0).
    z_score = np.abs(bins - mu) / (sigma + 1e-16)

    return (z_score, bins, counts, mu, sigma)
|
60
|
+
|
61
|
+
|
62
|
+
def _save_outlier_histogram(layer_name: str, zscore_hist: Tuple[np.ndarray, np.ndarray, np.ndarray, float, float],
                            z_threshold: float, img_filename: str):
    """
    Save output activation distributions histogram.

    The histogram is drawn as a bar chart whose bottom x-axis shows the bin values and whose
    top x-axis shows the matching z-scores. Bars classified as outliers are drawn in red, the
    others in blue. The lowest / highest bin and the z-score thresholds are marked with
    dashed vertical lines.

    Args:
        layer_name (str): The name of the layer. Used as the plot title.
        zscore_hist (Tuple[np.ndarray, np.ndarray, np.ndarray, float, float]): A tuple
            ``(z_score, bins, counts, mu, sigma)``. The z-score array itself is not used here.
        z_threshold (float): Threshold for detecting outliers.
        img_filename (str): Filename to save histogram image to.
    """
    _, bins, counts, mu, sigma = zscore_hist

    # The bars are positioned directly on the given bin values.
    bin_centers = bins

    # Guards every division by / scaling with sigma against sigma == 0.
    eps = 1e-16

    # Positions (in bin units) of the lower and upper z-score thresholds.
    lower_th_z_score_vline = mu - z_threshold * (sigma + eps)
    upper_th_z_score_vline = mu + z_threshold * (sigma + eps)

    def detect_outliers_after_peak(counts, bin_centers, lower_thresh, upper_thresh):
        """
        Filtering outlier data in the histogram that are outside the threshold and not monotonically decreasing.

        Args:
            counts (np.ndarray): z-score histogram counts.
            bin_centers (np.ndarray): the center data of histogram counts.
            lower_thresh (float): Threshold for detecting outliers(lower side).
            upper_thresh (float): Threshold for detecting outliers(upper side).
        Returns:
            colors (list): one color name per bin ('red' for outliers, 'blue' otherwise).
        """
        n_bins = len(counts)
        colors = ['blue'] * n_bins

        # Upper side: scanning left to right, the first bin beyond the upper threshold whose
        # count rises above its left neighbor marks itself and every bin after it as outliers.
        for i in range(1, n_bins):
            if bin_centers[i] > upper_thresh and counts[i] > counts[i - 1]:
                colors[i:] = ['red'] * (n_bins - i)
                break

        # Lower side: scanning right to left, the first bin below the lower threshold whose
        # count rises above its right neighbor marks itself and every bin before it as outliers.
        for i in range(n_bins - 2, -1, -1):
            if bin_centers[i] < lower_thresh and counts[i] > counts[i + 1]:
                colors[:i + 1] = ['red'] * (i + 1)
                break

        return colors

    # Detect outliers
    colors = detect_outliers_after_peak(counts, bin_centers, lower_th_z_score_vline, upper_th_z_score_vline)

    # A single-bin histogram has no spacing to derive the bar width from; fall back to a fixed width.
    bar_width = 0.9 * (bins[1] - bins[0]) if len(bins) > 1 else 0.9
    bins_min, bins_max = np.min(bins), np.max(bins)

    # Plot
    fig = plt.figure(figsize=(10, 5))
    try:
        ax = fig.add_subplot(111)
        ax_top = ax.twiny()

        ax.bar(bin_centers, counts, width=bar_width, color=colors)
        ax.set_ylabel('counts')
        ax.set_xlabel('bins')

        # Label the top axis with the z-score that corresponds to each bottom-axis tick.
        bin_ticks = ax.get_xticks()
        z_score_ticks = np.abs(bin_ticks - mu) / (sigma + eps)

        ax.set_xticks(bin_ticks)
        ax_top.set_xticks(bin_ticks)
        ax_top.set_xticklabels([f'{zs:.2f}' for zs in z_score_ticks])
        ax_top.set_xlabel('z-score')
        ax_top.set_ylim(0, np.max(counts) * 1.1)
        ax.set_title(layer_name)

        xlim_min, xlim_max = ax.get_xlim()
        xlim_range = xlim_max - xlim_min
        # The twin axis keeps its own x-limits; copy the bottom axis limits so that the
        # z-score ticks and everything drawn on the top axis line up with the bars.
        ax_top.set_xlim(xlim_min, xlim_max)
        y_top = ax_top.get_ylim()[1]

        # Mark the highest and lowest bin together with their z-scores.
        upper_zscore = abs(bins_max - mu) / (sigma + eps)
        ax.axvline(x=bins_max, color='black', linestyle='--')
        ax_top.text(bins_max - 0.18 * xlim_range, y_top * 0.9, f'upper zscore={upper_zscore:.2f}', color='black')

        lower_zscore = abs(bins_min - mu) / (sigma + eps)
        ax.axvline(x=bins_min, color='black', linestyle='--')
        ax_top.text(bins_min + 0.01 * xlim_range, y_top * 0.9, f'lower zscore={lower_zscore:.2f}', color='black')

        # Draw each threshold line when it is inside the plotted range; otherwise only
        # print a label with an arrow pointing to the side where the threshold lies.
        if upper_th_z_score_vline <= xlim_max:
            ax_top.axvline(x=upper_th_z_score_vline, color='r', linestyle='--')
            ax_top.text(upper_th_z_score_vline - 0.12 * xlim_range, y_top * 0.95, f'threshold={z_threshold}', color='red')
        else:
            ax_top.text(xlim_max - 0.18 * xlim_range, y_top * 0.95, f'threshold={z_threshold} ->', color='red')

        if xlim_min <= lower_th_z_score_vline:
            ax_top.axvline(x=lower_th_z_score_vline, color='r', linestyle='--')
            ax_top.text(lower_th_z_score_vline + 0.01 * xlim_range, y_top * 0.95, f'threshold={z_threshold}', color='red')
        else:
            ax_top.text(xlim_min + 0.01 * xlim_range, y_top * 0.95, f'<- threshold={z_threshold}', color='red')

        fig.tight_layout()
        fig.savefig(img_filename)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
187
|
+
|
188
|
+
|
189
|
+
def judge_outlier_removal(degrade_layers: list[str], float_graph: Graph,
                          xquant_config: XQuantConfig) -> List[Tuple[str, str]]:
    """
    Judge whether the degrade layers have outliers from statistics information

    For every degrade layer whose output histogram contains at least one bin with a z-score
    greater than or equal to ``xquant_config.threshold_zscore_outlier_removal``, a histogram
    image is saved under ``<report_dir>/outlier_histgrams`` (only when that directory already
    exists) and the layer is reported through the logger.

    Args:
        degrade_layers (list[str]): A list of detected degrade layers.
        float_graph (Graph): Graph to get statistics for the layer.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        List[Tuple[str, str]]: One ``(layer name, histogram image filename)`` tuple per layer
        with outliers. The filename is reported even when the image could not be saved because
        the output directory is missing. The list is empty when no layer matches the condition.
    """
    z_threshold = xquant_config.threshold_zscore_outlier_removal

    outlier_layers = []

    for layer_name in degrade_layers:
        nodes = float_graph.find_node_by_name(layer_name)
        assert len(nodes) == 1  # check length of nodes

        collector = float_graph.get_out_stats_collector(nodes[0])
        if collector is None:
            continue

        zscore_hist = _compute_zscore(collector)
        if zscore_hist is None:
            continue

        # Skip layers whose histogram has no bin at or beyond the z-score threshold.
        if not np.any(zscore_hist[0] >= z_threshold):
            continue

        histogram_dir = os.path.join(xquant_config.report_dir, 'outlier_histgrams')
        img_filename = os.path.join(histogram_dir, f'{layer_name}.png')
        if os.path.exists(histogram_dir):
            _save_outlier_histogram(layer_name, zscore_hist, z_threshold, img_filename)
        else:
            Logger.warning("Output directory of outlier histgram images({}/outlier_histgrams) not found. Skipping output outlier histgram images.".format(xquant_config.report_dir))

        # Print to Console (the explanatory header is printed once, before the first layer)
        if not outlier_layers:
            Logger.warning("There are output values that deviate significantly from the average. Refer to the following images and the TroubleShooting Documentation (MCT XQuant Extension Tool) of \'Outlier Removal\'.")
        Logger.warning(img_filename)

        outlier_layers.append((layer_name, img_filename))

    return outlier_layers
|
234
|
+
|
235
|
+
|
236
|
+
def judge_shift_negative_activation(float_graph: Graph,
                                    xquant_config: XQuantConfig) -> list[str]:
    """
    Collect the nodes whose layer is an activation that can output negative values
    (PReLU / ELU / Hardswish / SiLU / GELU, as module class or functional form).

    Every matching node is reported through the logger; an explanatory header is logged once,
    before the first match.

    Args:
        float_graph (Graph): Graph to get class name for the layer.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.
            Not read by this function.

    Returns:
        list[str]: Names of the matching nodes, in graph iteration order.

    """
    targets = (PReLU, prelu,
               ELU, elu,
               Hardswish, hardswish,
               SiLU, silu,
               GELU, gelu)

    matched_names = []

    for node in float_graph.nodes:
        if node.layer_class not in targets:
            continue

        # Print to Console
        if not matched_names:
            Logger.warning("There are activations that contain negative values. Refer to the troubleshooting manual of \"Shift Negative Activation\".")
        Logger.warning("{}={}".format(node.name, node.layer_class.__name__))

        matched_names.append(node.name)

    return matched_names
|
267
|
+
|
268
|
+
def _compute_activations(name: str, activations: dict):
    """
    Creates a forward hook function that records the first positional input of a layer.

    Note that the hook stores the value the layer receives (``input[0]``), not the value it
    returns.

    Args:
        name (str): The name of the layer. Used as the key in ``activations``.
        activations (dict): The dictionary to store the captured values. It is updated in
            place: each hook call appends one entry to the list stored under ``name``.

    Returns:
        hook (function): The hook function to register with the layer.
    """
    def hook(model, inputs, output):
        # Store the result of `.detach()` on the first positional input, not the input itself.
        captured = inputs[0].detach()
        activations.setdefault(name, []).append(captured)

    return hook
|
287
|
+
|
288
|
+
def judge_unbalanced_concatnation(degrade_layers: list[str],
|
289
|
+
float_model: torch.nn.Module,
|
290
|
+
dataset: Callable,
|
291
|
+
xquant_config: XQuantConfig) -> List[List[Tuple[str, str, str]]]:
|
292
|
+
"""
|
293
|
+
Judge whether the layer combines layers with significantly different value ranges
|
294
|
+
|
295
|
+
Args:
|
296
|
+
degrade_layers (list[str]): A list of detected degrade layers.
|
297
|
+
float_model (torch.nn.Module): The original floating-point Pytorch model.
|
298
|
+
dataset (Callable): Representative dataset used for similarity metrics computation.
|
299
|
+
xquant_config (XQuantConfig): Configuration settings for explainable quantization.
|
300
|
+
|
301
|
+
Returns:
|
302
|
+
List[List[Tuple[str, str, str]]]: A list containing layer name before concatnation, and scale adjustment.
|
303
|
+
If the layer does not match the condition, it returns None.
|
304
|
+
"""
|
305
|
+
|
306
|
+
judge_results = []
|
307
|
+
|
308
|
+
if(xquant_config.quantize_reported_dir is None):
|
309
|
+
Logger.warning("XQuantConfig.quantize_reported_dir is not defined. Skip judging of \'Unbalanced \"concatenation\"\'.")
|
310
|
+
return judge_results
|
311
|
+
|
312
|
+
org_torch_add = torch.add
|
313
|
+
org_torch_tensor_add = torch.Tensor.__add__
|
314
|
+
|
315
|
+
concat_layers = {}
|
316
|
+
concat_layers_add = {}
|
317
|
+
activations_float = {}
|
318
|
+
float_model_modules = dict([*float_model.named_modules()])
|
319
|
+
for layer_name in degrade_layers:
|
320
|
+
is_search = layer_name[-3:] == '_bn' or layer_name[-10:] == '_collapsed' or layer_name[:3] == 'bn_'
|
321
|
+
if is_search:
|
322
|
+
logdir = xquant_config.quantize_reported_dir
|
323
|
+
tblog_names = ['initial_graph', 'after_graph_preparation', 'pre_statistics_collection_substitutions']
|
324
|
+
tblog_to_nodename = {}
|
325
|
+
|
326
|
+
for tblog_name in tblog_names:
|
327
|
+
tfevent_paths = sorted(glob.glob(os.path.join(logdir, 'tensorboard_logs', tblog_name, 'events.out.tfevents.*')))
|
328
|
+
if(len(tfevent_paths) == 0):
|
329
|
+
tfevent_paths = sorted(glob.glob(os.path.join(logdir, '**', 'tensorboard_logs', tblog_name, 'events.out.tfevents.*')))
|
330
|
+
|
331
|
+
if(len(tfevent_paths) == 0):
|
332
|
+
Logger.warning("TensorBoard logs not found in XQuantConfig.quantize_reported_dir. Skip judging of \'Unbalanced \"concatenation\"\'.")
|
333
|
+
return judge_results
|
334
|
+
tfevent = EventAccumulator(path=tfevent_paths[0])
|
335
|
+
tfevent.Reload()
|
336
|
+
|
337
|
+
node_names = []
|
338
|
+
for log in str(tfevent.Graph()).splitlines():
|
339
|
+
if(log[:7] == ' name:'):
|
340
|
+
node_name = log.split('\"')[-2]
|
341
|
+
node_names.append(node_name)
|
342
|
+
|
343
|
+
tblog_to_nodename[tblog_name] = node_names
|
344
|
+
|
345
|
+
# layer_name = 'features_1_conv_0_0_bn'
|
346
|
+
|
347
|
+
after_graph_preparation_log = tblog_to_nodename['after_graph_preparation']
|
348
|
+
|
349
|
+
first_node, second_node = None, None
|
350
|
+
|
351
|
+
if layer_name[-3:] == '_bn': # features_1_conv_0_0_bn
|
352
|
+
|
353
|
+
target_name = layer_name[:-3] # features_1_conv_0_0
|
354
|
+
for idx, node_name in enumerate(after_graph_preparation_log):
|
355
|
+
_name = node_name.split('/')[-1] # node_name: MobileNetV2/Conv2d_1/features_1_conv_0_0, _name: features_1_conv_0_0
|
356
|
+
|
357
|
+
if _name == target_name:
|
358
|
+
_idx, _first_node = idx, node_name
|
359
|
+
_second_node = after_graph_preparation_log[_idx + 1]
|
360
|
+
continue
|
361
|
+
|
362
|
+
elif layer_name[-10:] == '_collapsed':
|
363
|
+
|
364
|
+
target_name = layer_name[:-10]
|
365
|
+
|
366
|
+
sorted_node_names_by_len = sorted(after_graph_preparation_log, key=len, reverse=True)
|
367
|
+
_first_node = None
|
368
|
+
for node_name in sorted_node_names_by_len:
|
369
|
+
_name = node_name.split('/')[-1]
|
370
|
+
if target_name.startswith(_name):
|
371
|
+
_first_node = _name
|
372
|
+
break
|
373
|
+
|
374
|
+
if(_first_node is not None):
|
375
|
+
for node_name in sorted_node_names_by_len:
|
376
|
+
_name = node_name.split('/')[-1]
|
377
|
+
target_name_exclude_first_node = target_name[len(_first_node)+1:]
|
378
|
+
if _name == target_name_exclude_first_node or _name+"_bn" == target_name_exclude_first_node:
|
379
|
+
_second_node = _name
|
380
|
+
break
|
381
|
+
|
382
|
+
elif layer_name[:3] == 'bn_':
|
383
|
+
|
384
|
+
target_name = layer_name[3:]
|
385
|
+
for idx, node_name in enumerate(after_graph_preparation_log):
|
386
|
+
_name = node_name.split('/')[-1]
|
387
|
+
|
388
|
+
if _name == target_name:
|
389
|
+
_idx, _second_node = idx, node_name
|
390
|
+
_first_node = after_graph_preparation_log[_idx - 1]
|
391
|
+
continue
|
392
|
+
|
393
|
+
if _first_node is not None and _second_node is not None:
|
394
|
+
first_node = _first_node.split('/')[-1].replace("_", ".") # features.1.conv.0.0
|
395
|
+
second_node = _second_node.split('/')[-1].replace("_", ".") # features.1.conv.0.1
|
396
|
+
if first_node in float_model_modules.keys() and second_node in float_model_modules.keys():
|
397
|
+
float_model_modules[first_node].register_forward_hook(_compute_activations(first_node, activations_float))
|
398
|
+
float_model_modules[second_node].register_forward_hook(_compute_activations(second_node, activations_float))
|
399
|
+
|
400
|
+
concat_layers[layer_name] = (first_node, second_node)
|
401
|
+
elif first_node in float_model_modules.keys() and second_node == "add":
|
402
|
+
float_model_modules[first_node].register_forward_hook(_compute_activations(first_node, activations_float))
|
403
|
+
concat_layers_add[layer_name] = first_node
|
404
|
+
|
405
|
+
# Hooks cannot be applied to add operations. Define temporarily wrapper functions of "torch.add" and "+" to capture values after the first_node.
|
406
|
+
add_activations = {}
|
407
|
+
for first_node in concat_layers_add.values():
|
408
|
+
add_activations[first_node] = []
|
409
|
+
|
410
|
+
def hook_add(x, y, *args, **kwargs):
    """
    Temporary replacement for ``torch.add`` that records the outputs of additions
    consuming the latest captured output of a tracked node.

    The real addition is always delegated to ``org_torch_add`` and its result is
    returned unchanged; recording is a side effect on ``add_activations``.

    Args:
        x (torch.Tensor): The first operand in the addition operation.
        y (torch.Tensor or number): The second operand in the addition operation.
        *args: Additional positional arguments passed to torch.add.
        **kwargs: Additional keyword arguments passed to torch.add.

    Returns:
        torch.Tensor: The result of the addition operation.
    """
    add_result = org_torch_add(x, y, *args, **kwargs)
    for node_name in add_activations.keys():
        captured = activations_float.get(node_name)
        # The tracked node may not have produced any output yet when this add runs
        # (e.g. an earlier residual connection), so there is nothing to compare with.
        if captured is None or len(captured) == 0:
            continue
        conv_output = captured[-1]
        if conv_output is None:
            continue
        # torch.equal only accepts tensors; scalar operands (e.g. torch.add(t, 1))
        # can never be the tracked output, so they are skipped instead of raising.
        if any(torch.is_tensor(operand) and torch.equal(operand, conv_output) for operand in (x, y)):
            add_activations[node_name].append(add_result.detach())
            # Only the first matching tracked node records this addition.
            break
    return add_result
|
430
|
+
|
431
|
+
def hook_tensor_add(self, other):
    """
    Temporary replacement for ``torch.Tensor.__add__`` (the ``+`` operator) that
    records the outputs of additions consuming the latest captured output of a
    tracked node.

    The real addition is always delegated to ``org_torch_tensor_add`` and its
    result is returned unchanged; recording is a side effect on ``add_activations``.

    Args:
        self (torch.Tensor): The left operand of the addition.
        other (torch.Tensor or number): The right operand of the addition.

    Returns:
        torch.Tensor: The result of the addition operation.
    """
    add_result = org_torch_tensor_add(self, other)
    for node_name in add_activations.keys():
        captured = activations_float.get(node_name)
        # The tracked node may not have produced any output yet when this add runs,
        # so there is nothing to compare with.
        if captured is None or len(captured) == 0:
            continue
        conv_output = captured[-1]
        if conv_output is None:
            continue
        # `tensor + 1` reaches this hook with a Python scalar as `other`;
        # torch.equal only accepts tensors, so non-tensor operands are skipped.
        if any(torch.is_tensor(operand) and torch.equal(operand, conv_output) for operand in (self, other)):
            add_activations[node_name].append(add_result.detach())
            # Only the first matching tracked node records this addition.
            break
    return add_result
|
449
|
+
|
450
|
+
# Replace temporarily wrapper add functions
|
451
|
+
torch.add = hook_add
|
452
|
+
torch.Tensor.__add__ = hook_tensor_add
|
453
|
+
|
454
|
+
# Perform a forward pass with the input data and capture activations
|
455
|
+
dataset = partial(PytorchDatasetUtils.prepare_dataset, dataset=dataset, is_validation=True)
|
456
|
+
if(( len(concat_layers) + len(concat_layers_add)) > 0):
|
457
|
+
for data in dataset():
|
458
|
+
with torch.no_grad():
|
459
|
+
_ = float_model(*data)
|
460
|
+
|
461
|
+
# Restore the original add functions
|
462
|
+
torch.add = org_torch_add
|
463
|
+
torch.Tensor.__add__ = org_torch_tensor_add
|
464
|
+
|
465
|
+
for layer_name in concat_layers.keys():
|
466
|
+
first_node, second_node = concat_layers[layer_name]
|
467
|
+
|
468
|
+
act_first_node = activations_float.get(first_node)
|
469
|
+
act_second_node = activations_float.get(second_node)
|
470
|
+
|
471
|
+
all_act_first_node, all_act_second_node = torch.cat(act_first_node), torch.cat(act_second_node)
|
472
|
+
min_act_first_node, min_act_second_node = torch.min(all_act_first_node).item(), torch.min(all_act_second_node).item()
|
473
|
+
max_act_first_node, max_act_second_node = torch.max(all_act_first_node).item(), torch.max(all_act_second_node).item()
|
474
|
+
|
475
|
+
# Calculate act range
|
476
|
+
range_first_node = max_act_first_node - min_act_first_node
|
477
|
+
range_second_node = max_act_second_node - min_act_second_node
|
478
|
+
|
479
|
+
# Calculate ratio
|
480
|
+
range_ratio = range_second_node / (range_first_node + 1e-10)
|
481
|
+
scaling_formula = "first layer * {}".format(range_ratio)
|
482
|
+
|
483
|
+
range_ratio_over1 = range_ratio if range_ratio >= 1.0 else 1/range_ratio
|
484
|
+
th_ratio = xquant_config.threshold_ratio_unbalanced_concatenation
|
485
|
+
if range_ratio_over1 >= th_ratio:
|
486
|
+
# Print to Console
|
487
|
+
if(len(judge_results) == 0):
|
488
|
+
Logger.warning("There are unbalanced range layers concatnated. Refer to the troubleshooting manual of \'Unbalanced \"concatenation\"\'.")
|
489
|
+
Logger.warning("first layer:{}, second layer:{}, if you add a scaling operation, recommended scaling:{}".format(first_node, second_node, scaling_formula))
|
490
|
+
|
491
|
+
judge_results.append((first_node, second_node, scaling_formula))
|
492
|
+
|
493
|
+
for layer_name in concat_layers_add.keys():
|
494
|
+
first_node = concat_layers_add[layer_name]
|
495
|
+
|
496
|
+
act_first_node = activations_float.get(first_node)
|
497
|
+
act_second_node = add_activations[first_node]
|
498
|
+
|
499
|
+
all_act_first_node, all_act_second_node = torch.cat(act_first_node), torch.cat(act_second_node)
|
500
|
+
min_act_first_node, min_act_second_node = torch.min(all_act_first_node).item(), torch.min(all_act_second_node).item()
|
501
|
+
max_act_first_node, max_act_second_node = torch.max(all_act_first_node).item(), torch.max(all_act_second_node).item()
|
502
|
+
|
503
|
+
# Calculate act range
|
504
|
+
range_first_node = max_act_first_node - min_act_first_node
|
505
|
+
range_second_node = max_act_second_node - min_act_second_node
|
506
|
+
|
507
|
+
# Calculate ratio
|
508
|
+
range_ratio = range_second_node / (range_first_node + 1e-10)
|
509
|
+
scaling_formula = "first layer * {}".format(range_ratio)
|
510
|
+
|
511
|
+
range_ratio_over1 = range_ratio if range_ratio >= 1.0 else 1/range_ratio
|
512
|
+
th_ratio = xquant_config.threshold_ratio_unbalanced_concatenation
|
513
|
+
if range_ratio_over1 >= th_ratio:
|
514
|
+
# Print to Console
|
515
|
+
if(len(judge_results) == 0):
|
516
|
+
Logger.warning("There are unbalanced range layers concatnated. Refer to the troubleshooting manual of \'Unbalanced \"concatenation\"\'.")
|
517
|
+
Logger.warning("first layer:{}, second layer:{}, if you add a scaling operation, recommended scaling:{}".format(first_node, second_node, scaling_formula))
|
518
|
+
|
519
|
+
judge_results.append((first_node, second_node, scaling_formula))
|
520
|
+
|
521
|
+
return judge_results
|
522
|
+
|
523
|
+
def judge_mixed_precision_with_model_output_loss_objective(quantized_model: torch.nn.Module,
                                                           xquant_config: XQuantConfig) -> list:
    """
    Judge whether the bitwidth of the final layer is less than or equal to the threshold.

    The second-to-last child of the model is treated as the final layer (its
    ``weights_quantizers['weight']`` is inspected when present) and the last child
    as its activation holder (its ``activation_holder_quantizer`` is inspected
    when present). A warning is logged when either bitwidth does not exceed
    ``xquant_config.threshold_bitwidth_mixed_precision_with_model_output_loss_objective``.

    Args:
        quantized_model (torch.nn.Module): The quantized Pytorch model.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        list: A single-element list holding the name of the final layer when the
        condition is met, otherwise an empty list. An empty list is also returned
        when the model has fewer than two children.
    """
    threshold_bitwidth = xquant_config.threshold_bitwidth_mixed_precision_with_model_output_loss_objective

    children = list(quantized_model.named_children())
    # Both the final layer and the module following it are required; without them
    # there is nothing to judge (indexing [-2] would otherwise raise IndexError).
    if len(children) < 2:
        return []

    last_layer_name, last_layer = children[-2]
    _, last_layer_activation = children[-1]

    # None means "this quantizer is not present on the module".
    bitwidth_weights = None
    if hasattr(last_layer, "weights_quantizers"):
        bitwidth_weights = last_layer.weights_quantizers['weight'].num_bits

    bitwidth_activation = None
    if hasattr(last_layer_activation, "activation_holder_quantizer"):
        bitwidth_activation = last_layer_activation.activation_holder_quantizer.num_bits

    is_mixed_precision_with_model_output_loss_objective = any(
        bitwidth is not None and bitwidth <= threshold_bitwidth
        for bitwidth in (bitwidth_weights, bitwidth_activation))

    if not is_mixed_precision_with_model_output_loss_objective:
        return []

    # Print to Console
    Logger.warning("the quantization bitwidth of the last layer is an extremely small number. Refer to the troubleshooting manual of \'Mixed Precision with model output loss objective\'.")
    Logger.warning("bidwidth of {}:{}(Weight), {}(Activation)".format(last_layer_name, bitwidth_weights, bitwidth_activation))

    return [last_layer_name]
|
@@ -19,18 +19,17 @@ from model_compression_toolkit.target_platform_capabilities.constants import DEF
|
|
19
19
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
20
20
|
AttachTpcToPytorch
|
21
21
|
|
22
|
-
from model_compression_toolkit.xquant.
|
23
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import
|
22
|
+
from model_compression_toolkit.xquant.pytorch.framework_report_utils import FrameworkReportUtils
|
23
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
24
24
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
25
25
|
from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
|
26
|
-
from model_compression_toolkit.xquant.
|
26
|
+
from model_compression_toolkit.xquant.pytorch.similarity_calculator import SimilarityCalculator
|
27
27
|
from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatasetUtils
|
28
28
|
from model_compression_toolkit.xquant.pytorch.model_analyzer import PytorchModelAnalyzer
|
29
29
|
from model_compression_toolkit.xquant.pytorch.similarity_functions import PytorchSimilarityFunctions
|
30
30
|
from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
|
31
31
|
from mct_quantizers.pytorch.metadata import get_metadata
|
32
32
|
|
33
|
-
|
34
33
|
class PytorchReportUtils(FrameworkReportUtils):
|
35
34
|
"""
|
36
35
|
Class with various utility components required for generating the report for a Pytorch model.
|
@@ -40,6 +39,7 @@ class PytorchReportUtils(FrameworkReportUtils):
|
|
40
39
|
Args:
|
41
40
|
report_dir: Logging dir path.
|
42
41
|
"""
|
42
|
+
fw_info = DEFAULT_PYTORCH_INFO
|
43
43
|
fw_impl = PytorchImplementation()
|
44
44
|
# Set the default Target Platform Capabilities (TPC) for PyTorch.
|
45
45
|
default_tpc = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
@@ -47,7 +47,8 @@ class PytorchReportUtils(FrameworkReportUtils):
|
|
47
47
|
framework_quantization_capabilities = attach2pytorch.attach(default_tpc)
|
48
48
|
|
49
49
|
dataset_utils = PytorchDatasetUtils()
|
50
|
-
model_folding = ModelFoldingUtils(
|
50
|
+
model_folding = ModelFoldingUtils(fw_info=fw_info,
|
51
|
+
fw_impl=fw_impl,
|
51
52
|
fw_default_fqc=framework_quantization_capabilities)
|
52
53
|
|
53
54
|
similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
|
@@ -57,9 +58,11 @@ class PytorchReportUtils(FrameworkReportUtils):
|
|
57
58
|
device=get_working_device())
|
58
59
|
|
59
60
|
tb_utils = PytorchTensorboardUtils(report_dir=report_dir,
|
60
|
-
fw_impl=fw_impl
|
61
|
+
fw_impl=fw_impl,
|
62
|
+
fw_info=fw_info)
|
61
63
|
|
62
|
-
super().__init__(
|
64
|
+
super().__init__(fw_info=fw_info,
|
65
|
+
fw_impl=fw_impl,
|
63
66
|
tb_utils=tb_utils,
|
64
67
|
dataset_utils=dataset_utils,
|
65
68
|
similarity_calculator=similarity_calculator,
|
@@ -15,6 +15,7 @@
|
|
15
15
|
from functools import partial
|
16
16
|
|
17
17
|
from typing import Tuple, Any, Dict, Callable
|
18
|
+
import torch
|
18
19
|
|
19
20
|
from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
|
20
21
|
from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
|
@@ -57,6 +58,7 @@ class SimilarityCalculator:
|
|
57
58
|
similarity_metrics: Dict[str, Callable]) -> Dict[str, float]:
|
58
59
|
"""
|
59
60
|
Compute the similarity between two tensors using provided similarity metrics.
|
61
|
+
If the type of inputs are not tensor, the similarity is 0.0.
|
60
62
|
|
61
63
|
Args:
|
62
64
|
tensors_to_compare (Tuple[Any, Any]): Tensors to compare by computing their similarity.
|
@@ -66,7 +68,10 @@ class SimilarityCalculator:
|
|
66
68
|
Dict[str, float]: A dictionary of similarity metric names and their computed values.
|
67
69
|
"""
|
68
70
|
x, y = tensors_to_compare
|
69
|
-
|
71
|
+
if(torch.is_tensor(x) and torch.is_tensor(y)):
|
72
|
+
similarity_metrics = {k: v(x, y) for k, v in similarity_metrics.items()}
|
73
|
+
else:
|
74
|
+
similarity_metrics = {k: 0.0 for k, v in similarity_metrics.items()}
|
70
75
|
return similarity_metrics
|
71
76
|
|
72
77
|
def _get_float_to_quantized_compare_points(self,
|
@@ -41,15 +41,18 @@ class PytorchTensorboardUtils(TensorboardUtils):
|
|
41
41
|
|
42
42
|
def __init__(self, report_dir: str, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation):
    """
    Build the PyTorch flavour of the tensorboard utilities.

    All three arguments are forwarded, in order, to the parent class initializer.

    Args:
        report_dir: Directory where the reports are stored.
        fw_info: Information about the framework being used.
        fw_impl: Implementation methods for the framework.
    """
    super().__init__(report_dir, fw_info, fw_impl)
|
54
57
|
|
55
58
|
def get_graph_for_tensorboard_display(self,
|