mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/xquant/common/model_folding_utils.py
@@ -34,6 +34,7 @@ class ModelFoldingUtils:
     """
 
     def __init__(self,
+                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  fw_default_fqc: FrameworkQuantizationCapabilities):
         """
@@ -41,9 +42,11 @@ class ModelFoldingUtils:
         and default FQC.
 
         Args:
+            fw_info: Framework-specific information.
            fw_impl: Implementation functions for the framework.
            fw_default_fqc: Default target platform capabilities for the handled framework.
        """
+        self.fw_info = fw_info
        self.fw_impl = fw_impl
        self.fw_default_fqc = fw_default_fqc
 
@@ -66,7 +69,8 @@ class ModelFoldingUtils:
        float_folded_model, _ = self.fw_impl.model_builder(
            float_graph,
            mode=ModelBuilderMode.FLOAT,
-            append2output=None
+            append2output=None,
+            fw_info=self.fw_info
        )
        return float_folded_model
 
@@ -96,6 +100,7 @@ class ModelFoldingUtils:
        graph = graph_preparation_runner(in_model=model,
                                         representative_data_gen=repr_dataset,
                                         fw_impl=self.fw_impl,
+                                         fw_info=self.fw_info,
                                         quantization_config=DEFAULTCONFIG,
                                         fqc=self.fw_default_fqc)
        return graph
model_compression_toolkit/xquant/common/tensorboard_utils.py
@@ -36,16 +36,19 @@ class TensorboardUtils:
 
     def __init__(self,
                  report_dir: str,
+                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation):
         """
         Initialize the TensorboardUtils.
 
         Args:
             report_dir (str): Directory where Tensorboard logs will be stored.
+            fw_info (FrameworkInfo): Framework-specific information.
             fw_impl (FrameworkImplementation): Framework-specific implementation.
         """
         self.fw_impl = fw_impl
-        self.
+        self.fw_info = fw_info
+        self.tb_writer = TensorboardWriter(report_dir, fw_info)
         Logger.info(f"Please run: tensorboard --logdir {self.tb_writer.dir_path}")
 
     def get_graph_for_tensorboard_display(self,
model_compression_toolkit/xquant/common/xquant_config.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 from typing import Dict, Callable
+from model_compression_toolkit.logger import Logger
 
 
 class XQuantConfig:
@@ -25,13 +26,38 @@ class XQuantConfig:
 
     def __init__(self,
                  report_dir: str,
-                 custom_similarity_metrics: Dict[str, Callable] = None
+                 custom_similarity_metrics: Dict[str, Callable] = None,
+                 quantize_reported_dir: str = None,
+                 threshold_quantize_error: Dict[str, float] = {"mse": 0.1, "cs": 0.1, "sqnr": 0.1},
+                 is_detect_under_threshold_quantize_error: Dict[str, bool] = {"mse": False, "cs": True, "sqnr": True},
+                 threshold_degrade_layer_ratio: float = 0.5,
+                 threshold_zscore_outlier_removal: float = 5.0,
+                 threshold_ratio_unbalanced_concatenation: float = 16.0,
+                 threshold_bitwidth_mixed_precision_with_model_output_loss_objective: int = 2
+                 ):
         """
         Initializes the configuration for explainable quantization.
 
         Args:
             report_dir (str): Directory where the reports will be saved.
             custom_similarity_metrics (Dict[str, Callable]): Custom similarity metrics to be computed between tensors of the two models. The dictionary keys are similarity metric names and the values are callables that implement the similarity metric computation.
+            quantize_reported_dir (str): Directory where the the quantization log will be saved.
+            threshold_quantize_error (Dict[str, float]): Threshold values for detecting degradation in accuracy.
+            is_detect_under_threshold_quantize_error (Dict[str, bool]): For each threshold specified in threshold_quantize_error, True: detect the layer as degraded when the error is below the threshold.; False: detect the layer as degraded when the error is above the threshold.
+            threshold_degrade_layer_ratio (float): If the number of layers detected as degraded is large, skips the judge degradation causes Specify the ratio here.
+            threshold_zscore_outlier_removal (float): Used in judge degradation causes (Outlier Removal). Threshold for z_score to detect outliers.
+            threshold_ratio_unbalanced_concatenation (float): Used in judge degradation causes (unbalanced “concatnation”). Threshold for the multiplier of range width between concatenated layers.
+            threshold_bitwidth_mixed_precision_with_model_output_loss_objective (int): Used in judge degradation causes (Mixed precision with model output loss objective). Bitwidth of the final layer to judge insufficient bitwidth.
         """
+
         self.report_dir = report_dir
         self.custom_similarity_metrics = custom_similarity_metrics
+        self.quantize_reported_dir = quantize_reported_dir
+        if(self.quantize_reported_dir is None):
+            self.quantize_reported_dir = Logger.LOG_PATH
+        self.threshold_quantize_error = threshold_quantize_error
+        self.is_detect_under_threshold_quantize_error = is_detect_under_threshold_quantize_error
+        self.threshold_degrade_layer_ratio = threshold_degrade_layer_ratio
+        self.threshold_zscore_outlier_removal = threshold_zscore_outlier_removal
+        self.threshold_ratio_unbalanced_concatenation = threshold_ratio_unbalanced_concatenation
+        self.threshold_bitwidth_mixed_precision_with_model_output_loss_objective = threshold_bitwidth_mixed_precision_with_model_output_loss_objective
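The new constructor arguments above drive the added troubleshooting flow: per-metric error thresholds, the direction of each comparison, and the limits used by the individual cause checks. A minimal usage sketch, assuming only the signature added in this hunk (the directory and threshold values are illustrative, not recommendations):

    from model_compression_toolkit.xquant import XQuantConfig

    # Flag a layer as degraded when its MSE is at or above 0.05, or when its cosine
    # similarity / SQNR is at or below the given value (per the per-metric direction flags).
    xquant_config = XQuantConfig(
        report_dir="./xquant_report",  # placeholder path
        threshold_quantize_error={"mse": 0.05, "cs": 0.2, "sqnr": 5.0},
        is_detect_under_threshold_quantize_error={"mse": False, "cs": True, "sqnr": True},
        threshold_zscore_outlier_removal=5.0,
        threshold_bitwidth_mixed_precision_with_model_output_loss_objective=2)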
model_compression_toolkit/xquant/{common → keras}/core_report_generator.py
@@ -19,7 +19,7 @@ from model_compression_toolkit.core.common.model_collector import ModelCollector
 from model_compression_toolkit.xquant import XQuantConfig
 from model_compression_toolkit.xquant.common.constants import OUTPUT_SIMILARITY_METRICS_REPR, OUTPUT_SIMILARITY_METRICS_VAL, INTERMEDIATE_SIMILARITY_METRICS_REPR, \
     INTERMEDIATE_SIMILARITY_METRICS_VAL
-from model_compression_toolkit.xquant.
+from model_compression_toolkit.xquant.keras.framework_report_utils import FrameworkReportUtils
 
 
 def core_report_generator(float_model: Any,
@@ -50,7 +50,7 @@ def core_report_generator(float_model: Any,
 
     # Collect histograms on the float model.
     float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
-    mi = ModelCollector(float_graph, fw_report_utils.fw_impl)
+    mi = ModelCollector(float_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
     for _data in tqdm(repr_dataset(), desc="Collecting Histograms"):
         mi.infer(_data)
 
model_compression_toolkit/xquant/keras/facade_xquant_report.py
@@ -16,13 +16,13 @@
 from typing import Callable, Dict, Any
 
 from model_compression_toolkit.verify_packages import FOUND_TF
-from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
 from model_compression_toolkit.xquant import XQuantConfig
 from model_compression_toolkit.logger import Logger
 
 if FOUND_TF:
     import keras
     from model_compression_toolkit.xquant.keras.keras_report_utils import KerasReportUtils
+    from model_compression_toolkit.xquant.keras.core_report_generator import core_report_generator
 
 def xquant_report_keras_experimental(float_model: keras.Model,
                                      quantized_model: keras.Model,
model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py
@@ -20,10 +20,10 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from typing import Any, Dict, Callable
 
-from model_compression_toolkit.xquant.common.constants import REPORT_FILENAME
+from model_compression_toolkit.xquant.common.constants import REPORT_FILENAME, TROUBLESHOOT_REPORT_FILENAME
 from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
 from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
-from model_compression_toolkit.xquant.
+from model_compression_toolkit.xquant.keras.similarity_calculator import SimilarityCalculator
 from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
 from model_compression_toolkit.logger import Logger
 
@@ -34,6 +34,7 @@ class FrameworkReportUtils:
     """
 
     def __init__(self,
+                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  similarity_calculator: SimilarityCalculator,
                  dataset_utils: DatasetUtils,
@@ -44,6 +45,7 @@ class FrameworkReportUtils:
         Initializes the FrameworkReportUtils class with various utility components required for generating the report.
 
         Args:
+            fw_info (FrameworkInfo): Information about the framework being used.
             fw_impl (FrameworkImplementation): The implemented functions of the framework.
             similarity_calculator (SimilarityCalculator): A utility for calculating similarity metrics.
             dataset_utils (DatasetUtils): Utilities for handling datasets.
@@ -51,6 +53,7 @@ class FrameworkReportUtils:
             tb_utils (TensorboardUtils): Utilities for TensorBoard operations.
             get_metadata_fn (Callable): Function to retrieve the metadata from the quantized model.
         """
+        self.fw_info = fw_info
         self.fw_impl = fw_impl
         self.similarity_calculator = similarity_calculator
         self.dataset_utils = dataset_utils
@@ -75,3 +78,21 @@ class FrameworkReportUtils:
 
         with open(report_file_name, 'w') as f:
             json.dump(collected_data, f, indent=4)
+
+    def dump_troubleshoot_report_to_json(self,
+                                         report_dir: str,
+                                         collected_data: Dict[str, Any]):
+        """
+        Dump the collected data (similarity, etc.) into a JSON file.
+
+        Args:
+            report_dir (str): Directory where the report will be saved.
+            collected_data (Dict[str, Any]): Data collected during report generation.
+
+        """
+        report_file_name = os.path.join(report_dir, TROUBLESHOOT_REPORT_FILENAME)
+        report_file_name = os.path.abspath(report_file_name)
+        Logger.info(f"Dumping report data to: {report_file_name}")
+
+        with open(report_file_name, 'w') as f:
+            json.dump(collected_data, f, indent=4)
model_compression_toolkit/xquant/keras/keras_report_utils.py
@@ -15,12 +15,13 @@
 
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
     AttachTpcToKeras
-from model_compression_toolkit.xquant.
+from model_compression_toolkit.xquant.keras.framework_report_utils import FrameworkReportUtils
 from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
-from model_compression_toolkit.xquant.
+from model_compression_toolkit.xquant.keras.similarity_calculator import SimilarityCalculator
 from model_compression_toolkit.xquant.keras.dataset_utils import KerasDatasetUtils
 from model_compression_toolkit.xquant.keras.model_analyzer import KerasModelAnalyzer
 
@@ -39,6 +40,7 @@ class KerasReportUtils(FrameworkReportUtils):
         Args:
             report_dir: Logging dir path.
         """
+        fw_info = DEFAULT_KERAS_INFO
         fw_impl = KerasImplementation()
 
         # Set the default Target Platform Capabilities (TPC) for Keras.
@@ -47,7 +49,8 @@ class KerasReportUtils(FrameworkReportUtils):
         framework_platform_capabilities = attach2pytorch.attach(default_tpc)
 
         dataset_utils = KerasDatasetUtils()
-        model_folding = ModelFoldingUtils(
+        model_folding = ModelFoldingUtils(fw_info=fw_info,
+                                          fw_impl=fw_impl,
                                          fw_default_fqc=framework_platform_capabilities)
 
        similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
@@ -56,8 +59,10 @@ class KerasReportUtils(FrameworkReportUtils):
                                                     model_analyzer_utils=KerasModelAnalyzer())
 
        tb_utils = KerasTensorboardUtils(report_dir=report_dir,
-                                         fw_impl=fw_impl
-
+                                         fw_impl=fw_impl,
+                                         fw_info=fw_info)
+        super().__init__(fw_info,
+                         fw_impl,
                         similarity_calculator,
                         dataset_utils,
                         model_folding,
model_compression_toolkit/xquant/keras/similarity_calculator.py
@@ -0,0 +1,199 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from functools import partial
+
+from typing import Tuple, Any, Dict, Callable
+import numpy as np
+
+from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
+from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
+from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
+from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
+from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions
+from model_compression_toolkit.logger import Logger
+
+class SimilarityCalculator:
+    """
+    A class to calculate the similarity between two models (that are often referred as float
+    and quantized models). It utilizes various utility classes for dataset preparation, model folding,
+    similarity computation, and model analysis.
+    """
+
+    def __init__(self,
+                 dataset_utils: DatasetUtils,
+                 model_folding: ModelFoldingUtils,
+                 similarity_functions: SimilarityFunctions,
+                 model_analyzer_utils: ModelAnalyzer,
+                 device: str = None):
+        """
+        Initialize the SimilarityCalculator with required utilities.
+
+        Args:
+            dataset_utils (DatasetUtils): Utility class for dataset preparation.
+            model_folding (ModelFoldingUtils): Utility class for model folding operations.
+            similarity_functions (SimilarityFunctions): Class containing similarity functions.
+            model_analyzer_utils (ModelAnalyzer): Utility class for model analysis.
+            device (str, optional): Device to perform computations on (e.g., 'cpu', 'cuda'). Defaults to None.
+        """
+        self.dataset_utils = dataset_utils
+        self.model_folding = model_folding
+        self.similarity_functions = similarity_functions
+        self.model_analyzer_utils = model_analyzer_utils
+        self.device = device
+
+    @staticmethod
+    def compute_tensors_similarity(tensors_to_compare: Tuple[Any, Any],
+                                   similarity_metrics: Dict[str, Callable]) -> Dict[str, float]:
+        """
+        Compute the similarity between two tensors using provided similarity metrics.
+        If the type of inputs are not tensor, the similarity is 0.0.
+
+        Args:
+            tensors_to_compare (Tuple[Any, Any]): Tensors to compare by computing their similarity.
+            similarity_metrics (Dict[str, Callable]): A dictionary with similarity metric names and functions.
+
+        Returns:
+            Dict[str, float]: A dictionary of similarity metric names and their computed values.
+        """
+        x, y = tensors_to_compare
+        if(isinstance(x, np.ndarray) and isinstance(y, np.ndarray)):
+            similarity_metrics = {k: v(x, y) for k, v in similarity_metrics.items()}
+        else:
+            similarity_metrics = {k: 0.0 for k, v in similarity_metrics.items()}
+        return similarity_metrics
+
+    def _get_float_to_quantized_compare_points(self,
+                                               quantized_model: Any,
+                                               float_model: Any) -> Dict[str, str]:
+        """
+        Map corresponding layers between the float and quantized models for comparison.
+
+        Args:
+            quantized_model (Any): The quantized model.
+            float_model (Any): The float model.
+
+        Returns:
+            Dict[str, str]: A dictionary mapping float model layer names to quantized model layer names.
+        """
+        # Identify the points in the quantized model to compare.
+        quant_points_names = self.model_analyzer_utils.identify_quantized_compare_points(quantized_model)
+
+        float_name2quant_name = {}
+
+        # Extract the names of the layers in the float model.
+        float_layers_names = self.model_analyzer_utils.extract_float_layer_names(float_model)
+
+        # Map each quantized layer to the corresponding float layer.
+        for quant_point in quant_points_names:
+            candidate_float_layer_name = self.model_analyzer_utils.find_corresponding_float_layer(
+                quant_compare_point=quant_point, quantized_model=quantized_model)
+
+            if candidate_float_layer_name in float_layers_names:
+                if candidate_float_layer_name not in float_name2quant_name:
+                    float_name2quant_name[candidate_float_layer_name] = quant_point
+                else:
+                    Logger.critical(f"Duplicate mapping found for layer: {candidate_float_layer_name}.")
+            else:
+                Logger.warning(
+                    f"Could not find a matching layer in the float model for layer with name {quant_point}, "
+                    f"skipping it in similarity metrics comparison points computation.")
+
+        return float_name2quant_name
+
+    def compute_similarity_metrics(self,
+                                   float_model: Any,
+                                   quantized_model: Any,
+                                   dataset: Callable,
+                                   custom_similarity_metrics: Dict[str, Callable] = None,
+                                   is_validation: bool = False) -> Tuple[Dict[str, float], Dict[str, Dict[str, float]]]:
+        """
+        Compute the similarity metrics between the two models (usually, float and quantized models).
+
+        Args:
+            float_model (Any): The float model.
+            quantized_model (Any): The quantized model.
+            dataset (Callable): A callable to provide the dataset.
+            custom_similarity_metrics (Dict[str, Callable], optional): Custom similarity metrics. Defaults to None.
+            is_validation (bool, optional): Flag to indicate if the dataset is for validation. Defaults to False.
+
+        Returns:
+            Tuple[Dict[str, float], Dict[str, Dict[str, float]]]: Aggregated output similarity metrics and
+            intermediate similarity metrics for each layer.
+        """
+        # Prepare the dataset such that the rest of operations are indistinguishable between the representative
+        # dataset and the validation dataset.
+        dataset = partial(self.dataset_utils.prepare_dataset,
+                          dataset=dataset,
+                          is_validation=is_validation,
+                          device=self.device)
+
+        # Create a folded version of the float model.
+        float_model = self.model_folding.create_float_folded_model(float_model=float_model,
+                                                                   representative_dataset=dataset)
+
+        # Gather similarity metrics to compute (default and custom).
+        similarity_metrics_to_compute = self.similarity_functions.get_default_similarity_metrics()
+        if custom_similarity_metrics:
+            if not isinstance(custom_similarity_metrics, dict):
+                Logger.critical(
+                    f"custom_similarity_metrics should be a dictionary but is of type "
+                    f"{type(custom_similarity_metrics)}.")
+            similarity_metrics_to_compute.update(custom_similarity_metrics)
+
+        # Map float model layers to quantized model layers for comparison.
+        float_name2quant_name = self._get_float_to_quantized_compare_points(float_model=float_model,
+                                                                            quantized_model=quantized_model)
+
+        # Initialize dictionaries to store similarity metrics.
+        output_similarity_metrics = {key: [] for key in similarity_metrics_to_compute.keys()}
+        intermediate_similarity_metrics = {layer: {key: [] for key in similarity_metrics_to_compute.keys()} for layer in
+                                           float_name2quant_name.values()}
+
+        # Iterate over the dataset and compute similarity metrics.
+        for x in dataset():
+            # Extract activations and predictions from both models.
+            float_activations, quant_activations = (
+                self.model_analyzer_utils.extract_model_activations(
+                    float_model, quantized_model, float_name2quant_name, x))
+
+            float_predictions = float_activations[MODEL_OUTPUT_KEY]
+            quant_predictions = quant_activations[MODEL_OUTPUT_KEY]
+
+            # Compute similarity metrics for the output predictions.
+            output_results = self.compute_tensors_similarity((float_predictions, quant_predictions),
+                                                             similarity_metrics_to_compute)
+            for key in output_similarity_metrics:
+                output_similarity_metrics[key].append(output_results[key])
+
+            # Compute similarity metrics for each intermediate layer.
+            for float_layer, quant_layer in float_name2quant_name.items():
+                intermediate_results = self.compute_tensors_similarity(
+                    (float_activations[float_layer], quant_activations[quant_layer]),
+                    similarity_metrics_to_compute)
+                for key in intermediate_similarity_metrics[quant_layer]:
+                    intermediate_similarity_metrics[quant_layer][key].append(intermediate_results[key])
+
+        # Aggregate the output similarity metrics.
+        aggregated_output_similarity_metrics = {key: sum(value) / len(value) for key, value in
+                                                output_similarity_metrics.items()}
+
+        # Aggregate the intermediate similarity metrics for each layer.
+        for layer_name, layer_similarity_metrics in intermediate_similarity_metrics.items():
+            for similarity_name, similarity_values_list in layer_similarity_metrics.items():
+                if len(similarity_values_list) == 0:
+                    Logger.critical(f"Can not average similarities of an empty list.")
+                intermediate_similarity_metrics[layer_name][similarity_name] = sum(similarity_values_list) / len(similarity_values_list)
+
+        return aggregated_output_similarity_metrics, intermediate_similarity_metrics
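compute_tensors_similarity above applies each metric callable only when both inputs are numpy arrays, and otherwise falls back to 0.0 per metric. A small sketch of that contract, assuming the package is installed (the MSE lambda is illustrative; the packaged defaults come from SimilarityFunctions):

    import numpy as np
    from model_compression_toolkit.xquant.keras.similarity_calculator import SimilarityCalculator

    metrics = {"mse": lambda x, y: float(np.mean((x - y) ** 2))}
    a, b = np.ones(4), np.zeros(4)

    # Both inputs are ndarrays, so the metric is evaluated: {'mse': 1.0}
    print(SimilarityCalculator.compute_tensors_similarity((a, b), metrics))

    # A non-array input makes every metric fall back to 0.0: {'mse': 0.0}
    print(SimilarityCalculator.compute_tensors_similarity((a, None), metrics))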
model_compression_toolkit/xquant/keras/tensorboard_utils.py
@@ -40,15 +40,18 @@ class KerasTensorboardUtils(TensorboardUtils):
     """
 
     def __init__(self, report_dir: str,
+                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation):
         """
         Initialize the KerasTensorboardUtils class with the given parameters.
 
         Args:
             report_dir (str): Directory where the TensorBoard files will be stored.
+            fw_info (FrameworkInfo): Information about the framework being used.
             fw_impl (FrameworkImplementation): Implementation functions for the framework.
         """
         super().__init__(report_dir,
+                         fw_info,
                          fw_impl)
 
     def get_graph_for_tensorboard_display(self,
model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py
@@ -0,0 +1,77 @@
+# Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable, Any, Dict
+from model_compression_toolkit.xquant import XQuantConfig
+from model_compression_toolkit.xquant.pytorch.detect_degrade_utils import make_similarity_graph
+from model_compression_toolkit.logger import Logger
+
+def core_detect_degrade_layer(repl_similarity: tuple[Dict[str, Any], Dict[str, Any]],
+                              val_similarity: tuple[Dict[str, Any], Dict[str, Any]],
+                              xquant_config: XQuantConfig) -> list[str]:
+    """
+    Detect degrade layers by caliculated similarities by XQuant.
+    And Draw and save similarity graphs.
+
+    Args:
+        repr_similarity (tuple[Dict[str, Any], Dict[str, Any]]): Quant error reports of Representative dataset.
+        val_similarity (tuple[Dict[str, Any], Dict[str, Any]]): Quant error reports of Validation dataset.
+        xquant_config (XQuantConfig): Configuration settings for explainable quantization.
+
+    Returns:
+        List[str]: A list of detected degrade layers.
+    """
+
+    degrade_layers = []
+    metrics_names = repl_similarity[0].keys()
+    for metrics_name in metrics_names:
+        # Check xquant_config parameter of custom_similarity_metrics (If not threshold then skip. If not flag of under/above threshold detection then assume above threshold detection.)
+        if(metrics_name not in xquant_config.threshold_quantize_error.keys()):
+            Logger.warning("XQuantConfig.threshold_quantize_error[\'{}\'] is not defined. Skipping detection degrade layers by \'{}\'".format(metrics_name, metrics_name))
+            continue
+        if(metrics_name not in xquant_config.is_detect_under_threshold_quantize_error.keys()):
+            Logger.warning("XQuantConfig.is_detect_under_threshold_quantize_error[{}] is not defined. Assume =False".format(metrics_name))
+            xquant_config.is_detect_under_threshold_quantize_error[metrics_name] = False
+
+        for dataset_similarity, dataset_name in [(repl_similarity, "repr"), (val_similarity, "val")]:
+            degrade_layers_tmp = []
+            intermediate_similarity = dataset_similarity[1]
+            for layer_name in intermediate_similarity.keys():
+                quantize_error = intermediate_similarity[layer_name][metrics_name]
+                threshold_quantize_error = xquant_config.threshold_quantize_error[metrics_name]
+
+                # Switch by under/above threshold flag of xquant_config
+                is_degrade = False
+                if(xquant_config.is_detect_under_threshold_quantize_error[metrics_name]):
+                    if(quantize_error <= threshold_quantize_error):
+                        is_degrade = True
+                else:
+                    if(quantize_error >= threshold_quantize_error):
+                        is_degrade = True
+
+                # Add degrade_layers
+                if(is_degrade):
+                    # Print to Console
+                    if(len(degrade_layers)==0):
+                        Logger.info("This may be problematic because the quantization error is larger than other layers. Refer to the TroubleShooting Documentation (MCT XQuant Extension Tool).")
+                    Logger.info("{}[{}]={}".format(metrics_name, layer_name, quantize_error))
+                    if(layer_name not in degrade_layers_tmp):
+                        degrade_layers_tmp.append(layer_name)
+                    if(layer_name not in degrade_layers):
+                        degrade_layers.append(layer_name)
+            # Draw and save similarity graph by matplotlib
+            make_similarity_graph(metrics_name, dataset_name, intermediate_similarity, degrade_layers_tmp, xquant_config)
+
+    return degrade_layers
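The detection rule above is a per-metric threshold test whose direction comes from is_detect_under_threshold_quantize_error. A standalone restatement of that branch (a sketch of the decision only; the packaged function also logs hits and plots per-layer similarities via make_similarity_graph):

    def is_degraded(quantize_error: float, threshold: float, detect_under_threshold: bool) -> bool:
        # Error-like metrics (e.g. "mse") degrade when the value is at or above the threshold;
        # similarity-like metrics (e.g. "cs", "sqnr") degrade when it is at or below the threshold.
        if detect_under_threshold:
            return quantize_error <= threshold
        return quantize_error >= threshold

    assert is_degraded(0.3, 0.1, detect_under_threshold=False)   # MSE too large
    assert is_degraded(0.05, 0.1, detect_under_threshold=True)   # cosine similarity too small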
model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from typing import Callable, Any, Dict
+from model_compression_toolkit.xquant import XQuantConfig
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.xquant.pytorch.judge_troubleshoot_utils import judge_outlier_removal, judge_shift_negative_activation, judge_unbalanced_concatnation, judge_mixed_precision_with_model_output_loss_objective
+
+def core_judge_troubleshoot(float_model: Any,
+                            quantized_model: Any,
+                            float_graph: Graph,
+                            degrade_layers: list[str],
+                            dataset: Callable,
+                            xquant_config: XQuantConfig) -> Dict[str, Any]:
+    """
+    Judge whether judgeable troubleshoots and make troubleshoot report.
+
+    Args:
+        float_model (Any): The original floating-point model.
+        quantized_model (Any): The model after quantization.
+        float_graph (Graph): Graph of float_model with histgrams.
+        degrade_layers (list[str]): A list of detected degrade layers.
+        dataset (Callable): Representative dataset used for similarity metrics computation.
+        xquant_config (XQuantConfig): Configuration settings for explainable quantization.
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the analyze degrade cause report for degraded layers.
+    """
+
+    _troubleshoot_data = {"outlier_removal":[], "shift_negative_activation":[], "unbalanced_concatenation":[], "mixed_precision_with_model_output_loss_objective":[]}
+
+    # Judge whether the layer has outliers from statistics information
+    # make outlier image folder
+    outlier_histgram_dir = os.path.join(xquant_config.report_dir, "outlier_histgrams")
+    if(not os.path.exists(outlier_histgram_dir)):
+        os.mkdir(outlier_histgram_dir)
+    _troubleshoot_data["outlier_removal"] = judge_outlier_removal(degrade_layers, float_graph, xquant_config)
+
+    # Judge whether the layer combines layers with significantly different value ranges
+    _troubleshoot_data["unbalanced_concatenation"] = judge_unbalanced_concatnation(degrade_layers, float_model, dataset, xquant_config)
+
+    # Judge whether the layer has a negative activation function (PReLU / ELU / Hardswish / SiLU / GELU)
+    _troubleshoot_data["shift_negative_activation"] = judge_shift_negative_activation(float_graph, xquant_config)
+
+    # Judge whether the bitwidth of the final layer is less than threshold
+    _troubleshoot_data["mixed_precision_with_model_output_loss_objective"] = judge_mixed_precision_with_model_output_loss_objective(quantized_model, xquant_config)
+
+    # Delete no data key
+    for troubleshoot_name in list(_troubleshoot_data.keys()):
+        if(len(_troubleshoot_data[troubleshoot_name]) == 0):
+            del _troubleshoot_data[troubleshoot_name]
+
+    return _troubleshoot_data
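core_judge_troubleshoot builds one list per candidate cause and then drops the empty ones, so callers only see categories that actually fired. A self-contained sketch of the resulting report structure (layer names are placeholders; the four keys are taken from the function above):

    # Shape of the report before the "Delete no data key" pass.
    troubleshoot_data = {"outlier_removal": ["conv1"],
                         "shift_negative_activation": [],
                         "unbalanced_concatenation": [],
                         "mixed_precision_with_model_output_loss_objective": ["model_output"]}

    # Same pruning as above: remove causes with no affected layers.
    for name in list(troubleshoot_data.keys()):
        if len(troubleshoot_data[name]) == 0:
            del troubleshoot_data[name]

    print(troubleshoot_data)
    # -> {'outlier_removal': ['conv1'], 'mixed_precision_with_model_output_loss_objective': ['model_output']}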