mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from tqdm import tqdm
|
16
|
+
from typing import Callable, Any, Dict
|
17
|
+
|
18
|
+
from model_compression_toolkit.core.common.model_collector import ModelCollector
|
19
|
+
from model_compression_toolkit.xquant import XQuantConfig
|
20
|
+
from model_compression_toolkit.xquant.common.constants import OUTPUT_SIMILARITY_METRICS_REPR, OUTPUT_SIMILARITY_METRICS_VAL, INTERMEDIATE_SIMILARITY_METRICS_REPR, \
|
21
|
+
INTERMEDIATE_SIMILARITY_METRICS_VAL
|
22
|
+
from model_compression_toolkit.xquant.pytorch.framework_report_utils import FrameworkReportUtils
|
23
|
+
|
24
|
+
from model_compression_toolkit.xquant.pytorch.core_detect_degrade_layer import core_detect_degrade_layer
|
25
|
+
from model_compression_toolkit.xquant.pytorch.core_judge_troubleshoot import core_judge_troubleshoot
|
26
|
+
|
27
|
+
def core_report_generator(float_model: Any,
                          quantized_model: Any,
                          repr_dataset: Callable,
                          validation_dataset: Callable,
                          fw_report_utils: FrameworkReportUtils,
                          xquant_config: XQuantConfig) -> Dict[str, Any]:
    """
    Build the XQuant report comparing a quantized model against its float baseline.

    Steps:
      1. Read the quantized model's metadata.
      2. Fold the float model into a graph, run the representative dataset through it with a
         ModelCollector, and write the collected histograms to TensorBoard.
      3. Measure output and intermediate similarity metrics on both the representative and the
         validation datasets.
      4. Add a graph of the quantized model (annotated with the metrics and metadata) and a text
         summary to TensorBoard, then dump the metrics to a JSON file in ``xquant_config.report_dir``.

    Args:
        float_model (Any): The original floating-point model.
        quantized_model (Any): The model after quantization.
        repr_dataset (Callable): Representative dataset, used for histogram collection and similarity metrics.
        validation_dataset (Callable): Validation dataset, used for similarity metrics.
        fw_report_utils (FrameworkReportUtils): Framework-specific reporting utilities.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        Dict[str, Any]: The similarity metrics, keyed by OUTPUT_SIMILARITY_METRICS_REPR,
        OUTPUT_SIMILARITY_METRICS_VAL, INTERMEDIATE_SIMILARITY_METRICS_REPR and
        INTERMEDIATE_SIMILARITY_METRICS_VAL.
    """
    metadata = fw_report_utils.get_metadata_fn(quantized_model)

    # Histogram collection happens on the folded float graph, driven by the representative data.
    folded_float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
    collector = ModelCollector(folded_float_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
    for batch in tqdm(repr_dataset(), desc="Collecting Histograms"):
        collector.infer(batch)

    tb_utils = fw_report_utils.tb_utils
    tb_utils.add_histograms_to_tensorboard(graph=folded_float_graph)

    # Each result is indexed below as [0] -> output metrics and [1] -> intermediate metrics,
    # matching the report keys they are stored under.
    calculator = fw_report_utils.similarity_calculator
    custom_metrics = xquant_config.custom_similarity_metrics
    repr_result = calculator.compute_similarity_metrics(float_model=float_model,
                                                        quantized_model=quantized_model,
                                                        dataset=repr_dataset,
                                                        custom_similarity_metrics=custom_metrics)
    val_result = calculator.compute_similarity_metrics(float_model=float_model,
                                                       quantized_model=quantized_model,
                                                       dataset=validation_dataset,
                                                       custom_similarity_metrics=custom_metrics,
                                                       is_validation=True)

    # Insertion order is kept stable because it determines the JSON layout.
    similarity_metrics = {}
    similarity_metrics[OUTPUT_SIMILARITY_METRICS_REPR] = repr_result[0]
    similarity_metrics[OUTPUT_SIMILARITY_METRICS_VAL] = val_result[0]
    similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_REPR] = repr_result[1]
    similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_VAL] = val_result[1]

    # Visualize the quantized model with its metrics, then add the textual summary.
    tb_utils.add_graph_to_tensorboard(quantized_model, similarity_metrics, repr_dataset, metadata)
    tb_utils.add_text_information(similarity_metrics, metadata)

    fw_report_utils.dump_report_to_json(report_dir=xquant_config.report_dir,
                                        collected_data=similarity_metrics)
    return similarity_metrics
|
94
|
+
|
95
|
+
def core_report_generator_troubleshoot(float_model: Any,
                                       quantized_model: Any,
                                       repr_dataset: Callable,
                                       validation_dataset: Callable,
                                       fw_report_utils: FrameworkReportUtils,
                                       xquant_config: XQuantConfig) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Generate the XQuant report and additionally troubleshoot degraded layers.

    Builds the standard XQuant report: TensorBoard histograms of the folded float model, similarity
    metrics on the representative and validation datasets, a TensorBoard graph and text summary,
    and a JSON dump of the metrics. It then detects degraded layers from the similarity results.

    The degradation-cause analysis (``core_judge_troubleshoot``) runs only when the fraction of
    degraded layers is below ``xquant_config.threshold_degrade_layer_ratio``. Its result is written
    to a troubleshoot JSON file in ``xquant_config.report_dir``.

    Args:
        float_model (Any): The original floating-point model.
        quantized_model (Any): The model after quantization.
        repr_dataset (Callable): Representative dataset used for histograms and similarity metrics.
        validation_dataset (Callable): Validation dataset used for similarity metrics.
        fw_report_utils (FrameworkReportUtils): Utilities for generating framework-specific reports.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        tuple[Dict[str, Any], Dict[str, Any]]: A pair ``(similarity_metrics, troubleshoot_data)``.
        ``similarity_metrics`` holds the collected similarity metrics. ``troubleshoot_data`` holds the
        degradation-cause analysis for the degraded layers, or ``{}`` when the analysis was skipped.
    """
    # Get metadata from the quantized model.
    quantized_model_metadata = fw_report_utils.get_metadata_fn(quantized_model)

    # Collect histograms on the float model.
    float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
    mi = ModelCollector(float_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
    for _data in tqdm(repr_dataset(), desc="Collecting Histograms"):
        mi.infer(_data)

    # Add the collected histograms to TensorBoard.
    fw_report_utils.tb_utils.add_histograms_to_tensorboard(graph=float_graph)

    # Compute similarity metrics on the representative dataset and the validation set.
    repr_similarity = fw_report_utils.similarity_calculator.compute_similarity_metrics(float_model=float_model,
                                                                                       quantized_model=quantized_model,
                                                                                       dataset=repr_dataset,
                                                                                       custom_similarity_metrics=xquant_config.custom_similarity_metrics)
    val_similarity = fw_report_utils.similarity_calculator.compute_similarity_metrics(float_model=float_model,
                                                                                      quantized_model=quantized_model,
                                                                                      dataset=validation_dataset,
                                                                                      custom_similarity_metrics=xquant_config.custom_similarity_metrics,
                                                                                      is_validation=True)
    similarity_metrics = {
        OUTPUT_SIMILARITY_METRICS_REPR: repr_similarity[0],
        OUTPUT_SIMILARITY_METRICS_VAL: val_similarity[0],
        INTERMEDIATE_SIMILARITY_METRICS_REPR: repr_similarity[1],
        INTERMEDIATE_SIMILARITY_METRICS_VAL: val_similarity[1]
    }

    # Add a graph of the quantized model with the similarity metrics to TensorBoard for visualization.
    fw_report_utils.tb_utils.add_graph_to_tensorboard(quantized_model,
                                                      similarity_metrics,
                                                      repr_dataset,
                                                      quantized_model_metadata)

    # Add text information (like max cut and output similarity metrics) to the TensorBoard writer.
    fw_report_utils.tb_utils.add_text_information(similarity_metrics,
                                                  quantized_model_metadata)

    # Save the similarity metrics to a json file.
    fw_report_utils.dump_report_to_json(report_dir=xquant_config.report_dir,
                                        collected_data=similarity_metrics)

    # Detect degraded layers from the similarity results.
    degrade_layers = core_detect_degrade_layer(repr_similarity, val_similarity, xquant_config)

    # Analyze degradation causes only when the degraded-layer ratio is below the threshold.
    # With no intermediate measurements there is nothing degraded, so use a ratio of 0.0
    # instead of dividing by zero.
    num_measured_layers = len(repr_similarity[1])
    degrade_layer_ratio = len(degrade_layers) / num_measured_layers if num_measured_layers > 0 else 0.0
    if degrade_layer_ratio < xquant_config.threshold_degrade_layer_ratio:
        _troubleshoot_data = core_judge_troubleshoot(float_model, quantized_model, float_graph,
                                                     degrade_layers, repr_dataset, xquant_config)
    else:
        _troubleshoot_data = {}

    # Save the troubleshoot results to a json file.
    fw_report_utils.dump_troubleshoot_report_to_json(report_dir=xquant_config.report_dir,
                                                     collected_data=_troubleshoot_data)

    return similarity_metrics, _troubleshoot_data
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from model_compression_toolkit.xquant import XQuantConfig
|
17
|
+
import matplotlib.pyplot as plt
|
18
|
+
|
19
|
+
def make_similarity_graph(metrics_name: str,
                          dataset_name: str,
                          intermediate_similarity: dict,
                          degrade_layers: list[str],
                          xquant_config: XQuantConfig) -> None:
    """
    Draw a per-layer quantization-error plot and save it as a PNG in the report directory.

    One point is plotted per layer of ``intermediate_similarity``, in dict order. Layers listed in
    ``degrade_layers`` get a red circle marker and a red x-axis label. The threshold
    ``xquant_config.threshold_quantize_error[metrics_name]`` is drawn as a dashed red horizontal line.
    The image is written to ``{report_dir}/quant_loss_{metrics_name}_{dataset_name}.png``.

    Args:
        metrics_name (str): Name of the metric used to measure the quantization error (e.g. 'mse').
        dataset_name (str): Dataset name ('repr' for the representative dataset or 'val' for the validation dataset).
        intermediate_similarity (dict): Per-layer results, mapping layer name -> {metric name: value}.
        degrade_layers (list[str]): Names of the layers detected as degraded. Names that do not
            appear in ``intermediate_similarity`` are ignored.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        None. Nothing is drawn or saved when ``intermediate_similarity`` is empty.
    """
    plot_title = "quant_loss_{}_{}".format(metrics_name, dataset_name)

    # x: layer names (in dict order), y: the chosen metric for each layer.
    data_x = list(intermediate_similarity.keys())
    data_y = [intermediate_similarity[layer_name][metrics_name] for layer_name in data_x]
    if not data_x:
        # Nothing to plot; max()/min() below would also raise on an empty sequence.
        return None

    # Map layer name -> x position once instead of an O(n) list.index() per degraded layer.
    layer_positions = {layer_name: idx for idx, layer_name in enumerate(data_x)}
    degrade_layer_indexes = [layer_positions[layer] for layer in degrade_layers if layer in layer_positions]

    # Make the graph (width grows with the number of layers).
    fig = plt.figure(figsize=(max(int(len(data_x) / 3), 1), 5))
    try:
        # Draw the plot with hollow red circle markers on the degraded layers.
        plt.plot(data_x, data_y, markevery=degrade_layer_indexes, marker='o', markersize=20,
                 markeredgecolor='r', markerfacecolor=[0.0, 0.0, 0.0, 0.0], markeredgewidth=3)
        plt.grid()
        # Add labels.
        plt.title(plot_title)
        plt.xlabel("layer name")
        plt.ylabel(metrics_name)
        plt.xticks(rotation=90)
        # Color the x-axis labels of degraded layers red.
        _, plt_x_labels = plt.xticks()
        for degrade_layer_index in degrade_layer_indexes:
            plt_x_labels[degrade_layer_index].set_color("red")
        # Add the threshold line, with its label slightly above it.
        threshold = xquant_config.threshold_quantize_error[metrics_name]
        plt.hlines(threshold, 0, len(data_x) - 1, "red", linestyles='dashed')
        plt.text(0, threshold + (max(data_y) - min(data_y)) * 0.02, "threshold={}".format(threshold), color="red")

        plt.tight_layout()

        # Save the image.
        path_outfile = "{}/{}.png".format(xquant_config.report_dir, plot_title)
        plt.savefig(path_outfile)
    finally:
        # Release the figure even if drawing or saving fails, so figures do not accumulate.
        plt.close(fig)

    return None
|
@@ -16,12 +16,12 @@
|
|
16
16
|
from typing import Callable
|
17
17
|
|
18
18
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
19
|
-
from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
|
20
19
|
from model_compression_toolkit.xquant import XQuantConfig
|
21
20
|
from model_compression_toolkit.logger import Logger
|
22
21
|
|
23
22
|
if FOUND_TORCH:
|
24
23
|
from model_compression_toolkit.xquant.pytorch.pytorch_report_utils import PytorchReportUtils
|
24
|
+
from model_compression_toolkit.xquant.pytorch.core_report_generator import core_report_generator, core_report_generator_troubleshoot
|
25
25
|
import torch
|
26
26
|
|
27
27
|
def xquant_report_pytorch_experimental(float_model: torch.nn.Module,
|
@@ -58,7 +58,47 @@ if FOUND_TORCH:
|
|
58
58
|
|
59
59
|
return _collected_data
|
60
60
|
|
61
|
+
def xquant_report_troubleshoot_pytorch_experimental(float_model: torch.nn.Module,
                                                    quantized_model: torch.nn.Module,
                                                    repr_dataset: Callable,
                                                    validation_dataset: Callable,
                                                    xquant_config: XQuantConfig) -> tuple:
    """
    Generate an explainable quantization report for a quantized PyTorch model, detect degraded
    layers, and judge the causes of the degradation.

    Args:
        float_model (torch.nn.Module): The original floating-point PyTorch model.
        quantized_model (torch.nn.Module): The quantized PyTorch model.
        repr_dataset (Callable): The representative dataset used during quantization.
        validation_dataset (Callable): The validation dataset used for evaluation.
        xquant_config (XQuantConfig): Configuration settings for explainable quantization.

    Returns:
        tuple: A pair ``(collected_data, troubleshoot_data)``. ``collected_data`` is a dict with the
        collected similarity metrics and report data. ``troubleshoot_data`` is a dict with the
        degradation-cause analysis for the degraded layers.
    """
    # Initialize the logger with the report directory.
    Logger.set_log_file(log_folder=xquant_config.report_dir)
    try:
        pytorch_report_utils = PytorchReportUtils(xquant_config.report_dir)

        _collected_data, _troubleshoot_data = core_report_generator_troubleshoot(float_model=float_model,
                                                                                 quantized_model=quantized_model,
                                                                                 repr_dataset=repr_dataset,
                                                                                 validation_dataset=validation_dataset,
                                                                                 fw_report_utils=pytorch_report_utils,
                                                                                 xquant_config=xquant_config)
    finally:
        # Shut the logger down even when report generation raises.
        # Presumably this releases the log file configured above; the exception still propagates.
        Logger.shutdown()

    return _collected_data, _troubleshoot_data
|
95
|
+
|
96
|
+
|
61
97
|
else:
|
62
98
|
def xquant_report_pytorch_experimental(*args, **kwargs):
|
63
99
|
Logger.critical("PyTorch must be installed to use 'xquant_report_pytorch_experimental'. "
|
64
100
|
"The 'torch' package is missing.") # pragma: no cover
|
101
|
+
|
102
|
+
def xquant_report_troubleshoot_pytorch_experimental(*args, **kwargs):
    """
    Fallback stub defined when PyTorch is not installed.

    Accepts any arguments and reports through Logger.critical that the 'torch' package is required.
    """
    Logger.critical("PyTorch must be installed to use 'xquant_report_troubleshoot_pytorch_experimental'. "
                    "The 'torch' package is missing.")  # pragma: no cover
|
@@ -0,0 +1,98 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import json
|
17
|
+
import os
|
18
|
+
|
19
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
20
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
21
|
+
from typing import Any, Dict, Callable
|
22
|
+
|
23
|
+
from model_compression_toolkit.xquant.common.constants import REPORT_FILENAME, TROUBLESHOOT_REPORT_FILENAME
|
24
|
+
from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
|
25
|
+
from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
|
26
|
+
from model_compression_toolkit.xquant.pytorch.similarity_calculator import SimilarityCalculator
|
27
|
+
from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
|
28
|
+
from model_compression_toolkit.logger import Logger
|
29
|
+
|
30
|
+
|
31
|
+
class FrameworkReportUtils:
    """
    Class with various utility components required for generating the report in a specific framework,
    plus helpers that write the collected report data to JSON files.
    """

    def __init__(self,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 similarity_calculator: SimilarityCalculator,
                 dataset_utils: DatasetUtils,
                 model_folding_utils: ModelFoldingUtils,
                 tb_utils: TensorboardUtils,
                 get_metadata_fn: Callable):
        """
        Initializes the FrameworkReportUtils class with various utility components required for generating the report.

        Args:
            fw_info (FrameworkInfo): Information about the framework being used.
            fw_impl (FrameworkImplementation): The implemented functions of the framework.
            similarity_calculator (SimilarityCalculator): A utility for calculating similarity metrics.
            dataset_utils (DatasetUtils): Utilities for handling datasets.
            model_folding_utils (ModelFoldingUtils): Utilities for model folding operations.
            tb_utils (TensorboardUtils): Utilities for TensorBoard operations.
            get_metadata_fn (Callable): Function to retrieve the metadata from the quantized model.
        """
        self.fw_info = fw_info
        self.fw_impl = fw_impl
        self.similarity_calculator = similarity_calculator
        self.dataset_utils = dataset_utils
        self.model_folding_utils = model_folding_utils
        self.tb_utils = tb_utils
        self.get_metadata_fn = get_metadata_fn

    def dump_report_to_json(self,
                            report_dir: str,
                            collected_data: Dict[str, Any]):
        """
        Dump the collected report data (similarity, etc.) into the REPORT_FILENAME JSON file.

        Args:
            report_dir (str): Directory where the report will be saved. Created if it does not exist.
            collected_data (Dict[str, Any]): Data collected during report generation. Must be JSON-serializable.

        """
        self._dump_json(report_dir, REPORT_FILENAME, collected_data)

    def dump_troubleshoot_report_to_json(self,
                                         report_dir: str,
                                         collected_data: Dict[str, Any]):
        """
        Dump the collected troubleshoot data into the TROUBLESHOOT_REPORT_FILENAME JSON file.

        Args:
            report_dir (str): Directory where the troubleshoot report will be saved. Created if it does not exist.
            collected_data (Dict[str, Any]): Troubleshoot data collected during report generation.
                Must be JSON-serializable.

        """
        self._dump_json(report_dir, TROUBLESHOOT_REPORT_FILENAME, collected_data)

    @staticmethod
    def _dump_json(report_dir: str,
                   file_name: str,
                   data: Dict[str, Any]):
        """
        Write `data` as indented JSON to `<report_dir>/<file_name>`.

        Args:
            report_dir (str): Target directory; created (including parents) if missing.
            file_name (str): Name of the JSON file inside `report_dir`.
            data (Dict[str, Any]): JSON-serializable data to write.
        """
        report_file_name = os.path.abspath(os.path.join(report_dir, file_name))
        Logger.info(f"Dumping report data to: {report_file_name}")

        # Make sure the destination exists so the dump does not fail with FileNotFoundError
        # when the caller has not created the report directory yet. dirname() of an absolute
        # path is never empty, so makedirs is always given a valid path.
        os.makedirs(os.path.dirname(report_file_name), exist_ok=True)

        # Explicit encoding keeps the output independent of the platform's default locale.
        with open(report_file_name, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
|