mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
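The last six entries are pure renames: the internal quantization packages under core/keras and core/pytorch moved to quantizer. These modules are not part of the public API, so pinning the nightly version is the safer fix, but code that reaches into them can bridge both layouts with a guarded import; a minimal sketch (Keras side shown):

# Hedged sketch: bridging the internal module rename across the two nightlies.
# These are private MCT modules; prefer pinning the package version instead.
try:
    # layout in mct-nightly 2.4.2.20250926.532
    from model_compression_toolkit.core.keras.quantizer import fake_quant_builder
except ImportError:
    # layout in mct-nightly 2.4.0.20250925.543
    from model_compression_toolkit.core.keras.quantization import fake_quant_builder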
model_compression_toolkit/gptq/runner.py
CHANGED
@@ -37,6 +37,7 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
                 tb_w: TensorboardWriter,
                 tg: Graph,
                 tg_bias: Graph,
+                fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 hessian_info_service: HessianInfoService = None) -> Graph:
     """
@@ -51,6 +52,7 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
         tb_w: TensorBoardWriter object to log events.
         tg: Float Reference Graph.
         tg_bias: Graph of quantized model.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
         fw_impl: Framework implementation per framework
         hessian_info_service: HessianInfoService to fetch information based on the hessian approximation for the float model.
     Returns:
@@ -62,6 +64,7 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
                           gptq_config,
                           representative_data_gen,
                           fw_impl,
+                          fw_info,
                           hessian_info_service=hessian_info_service)

     if tb_w is not None:
@@ -74,6 +77,7 @@ def gptq_runner(tg: Graph,
                 gptq_config: GradientPTQConfig,
                 representative_data_gen: Callable,
                 gptq_representative_data_gen: Callable,
+                fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 tb_w: TensorboardWriter,
                 hessian_info_service: HessianInfoService = None) -> Graph:
@@ -87,6 +91,7 @@ def gptq_runner(tg: Graph,
        gptq_config: GradientPTQConfig with parameters about the tuning process.
        representative_data_gen: Dataset used for calibration.
        gptq_representative_data_gen: Dataset used for GPTQ training
+       fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
        tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
        hessian_info_service: HessianScoresService to fetch approximations of the hessian scores for the float model.
@@ -99,7 +104,7 @@ def gptq_runner(tg: Graph,
     #############################################
     # Apply Statistics Correction
     #############################################
-    tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)
+    tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)

     if tb_w is not None:
         tb_w.add_graph(tg_bias, 'after_bias_correction')
@@ -112,6 +117,7 @@ def gptq_runner(tg: Graph,
                           tb_w,
                           tg,
                           tg_bias,
+                          fw_info,
                           fw_impl,
                           hessian_info_service=hessian_info_service)

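The pattern repeated throughout this file, and in the PTQ and QAT facades below, is that a FrameworkInfo object is threaded back through the pipeline. A hedged sketch of the updated apply_statistics_correction call shape; tg, representative_data_gen, core_config, fw_impl and tb_w are placeholders from the surrounding code, and DEFAULT_KERAS_INFO is the Keras-side default shown in the facades below:

from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
    apply_statistics_correction
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO

fw_info = DEFAULT_KERAS_INFO
# fw_info is now passed between core_config and fw_impl:
tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config,
                                      fw_info, fw_impl, tb_w)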
model_compression_toolkit/pruning/keras/pruning_facade.py
CHANGED
@@ -17,7 +17,6 @@ from typing import Callable, Tuple, Union

 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
 from model_compression_toolkit.verify_packages import FOUND_TF
@@ -25,8 +24,10 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

 if FOUND_TF:
@@ -34,12 +35,11 @@ if FOUND_TF:
         AttachTpcToKeras
     from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
     from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
-    from model_compression_toolkit.core.keras.default_framework_info import
+    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from tensorflow.keras.models import Model

     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

-    @set_keras_info
     def keras_pruning_experimental(model: Model,
                                    target_resource_utilization: ResourceUtilization,
                                    representative_data_gen: Callable,
@@ -116,25 +116,30 @@ if FOUND_TF:

         target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
         # Attach tpc model to framework
-
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(target_platform_capabilities)

         # Convert the original Keras model to an internal graph representation.
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
-
+                                          target_platform_capabilities,
+                                          DEFAULT_KERAS_INFO,
                                           fw_impl)

         # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
         # as it prepares the graph for the pruning process.
-        float_graph_with_compression_config =
+        float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
+                                                                                      quant_config=DEFAULTCONFIG,
+                                                                                      mixed_precision_enable=False)

         # Create a Pruner object with the graph and configuration.
         pruner = Pruner(float_graph_with_compression_config,
+                        DEFAULT_KERAS_INFO,
                         fw_impl,
                         target_resource_utilization,
                         representative_data_gen,
                         pruning_config,
-
+                        target_platform_capabilities)

         # Apply the pruning process.
         pruned_graph = pruner.prune_graph()
model_compression_toolkit/pruning/pytorch/pruning_facade.py
CHANGED
@@ -23,9 +23,10 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
-from model_compression_toolkit.
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL


@@ -35,7 +36,7 @@ if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
     from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
         PruningPytorchImplementation
-    from model_compression_toolkit.core.pytorch.default_framework_info import
+    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from torch.nn import Module
     from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
         AttachTpcToPytorch
@@ -43,7 +44,6 @@ if FOUND_TORCH:
     # Set the default Target Platform Capabilities (TPC) for PyTorch.
     DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

-    @set_pytorch_info
     def pytorch_pruning_experimental(model: Module,
                                      target_resource_utilization: ResourceUtilization,
                                      representative_data_gen: Callable,
@@ -129,14 +129,18 @@ if FOUND_TORCH:
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
                                           framework_platform_capabilities,
+                                          DEFAULT_PYTORCH_INFO,
                                           fw_impl)

         # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
         # as it prepares the graph for the pruning process.
-        float_graph_with_compression_config =
+        float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
+                                                                                      quant_config=DEFAULTCONFIG,
+                                                                                      mixed_precision_enable=False)

         # Create a Pruner object with the graph and configuration.
         pruner = Pruner(float_graph_with_compression_config,
+                        DEFAULT_PYTORCH_INFO,
                         fw_impl,
                         target_resource_utilization,
                         representative_data_gen,
model_compression_toolkit/ptq/keras/quantization_facade.py
CHANGED
@@ -36,8 +36,9 @@ from model_compression_toolkit.metadata import create_model_metadata
 if FOUND_TF:
     from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
         AttachTpcToKeras
-    from model_compression_toolkit.core.keras.default_framework_info import
+    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
+    from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
@@ -48,7 +49,6 @@ if FOUND_TF:
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)


-    @set_keras_info
     def keras_post_training_quantization(in_model: Model,
                                          representative_data_gen: Callable,
                                          target_resource_utilization: ResourceUtilization = None,
@@ -121,20 +121,25 @@ if FOUND_TF:

         >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model, repr_datagen, ru, core_config=config)

-        For more configuration options, please take a look at our `API documentation <https://
+        For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

         """

         if core_config.debug_config.bypass:
             return in_model, None

+        fw_info = DEFAULT_KERAS_INFO
+
+        KerasModelValidation(model=in_model,
+                             fw_info=fw_info).validate()
+
         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
                                 "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                                 "API, or pass a valid mixed precision configuration.") # pragma: no cover

-        tb_w = init_tensorboard_writer()
+        tb_w = init_tensorboard_writer(fw_info)

         fw_impl = KerasImplementation()

@@ -148,6 +153,7 @@ if FOUND_TF:
         tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
                                                                 representative_data_gen=representative_data_gen,
                                                                 core_config=core_config,
+                                                                fw_info=fw_info,
                                                                 fw_impl=fw_impl,
                                                                 fqc=framework_platform_capabilities,
                                                                 target_resource_utilization=target_resource_utilization,
@@ -163,6 +169,7 @@ if FOUND_TF:
         graph_with_stats_correction = ptq_runner(tg,
                                                  representative_data_gen,
                                                  core_config,
+                                                 fw_info,
                                                  fw_impl,
                                                  tb_w)

@@ -172,7 +179,8 @@ if FOUND_TF:
                                  tb_w,
                                  similarity_baseline_graph,
                                  quantized_graph,
-                                 fw_impl
+                                 fw_impl,
+                                 fw_info)

         exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
         if framework_platform_capabilities.tpc.add_metadata:
model_compression_toolkit/ptq/pytorch/quantization_facade.py
CHANGED
@@ -34,7 +34,7 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
 from model_compression_toolkit.metadata import create_model_metadata

 if FOUND_TORCH:
-    from model_compression_toolkit.core.pytorch.default_framework_info import
+    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from torch.nn import Module
@@ -46,7 +46,6 @@ if FOUND_TORCH:

     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

-    @set_pytorch_info
     def pytorch_post_training_quantization(in_module: Module,
                                            representative_data_gen: Callable,
                                            target_resource_utilization: ResourceUtilization = None,
@@ -103,6 +102,8 @@ if FOUND_TORCH:
         if core_config.debug_config.bypass:
             return in_module, None

+        fw_info = DEFAULT_PYTORCH_INFO
+
         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
@@ -110,7 +111,7 @@ if FOUND_TORCH:
                                 "pytorch_post_training_quantization API, or pass a valid mixed precision "
                                 "configuration.") # pragma: no cover

-        tb_w = init_tensorboard_writer()
+        tb_w = init_tensorboard_writer(fw_info)

         fw_impl = PytorchImplementation()

@@ -124,6 +125,7 @@ if FOUND_TORCH:
         tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
                                                                 representative_data_gen=representative_data_gen,
                                                                 core_config=core_config,
+                                                                fw_info=fw_info,
                                                                 fw_impl=fw_impl,
                                                                 fqc=framework_platform_capabilities,
                                                                 target_resource_utilization=target_resource_utilization,
@@ -139,6 +141,7 @@ if FOUND_TORCH:
         graph_with_stats_correction = ptq_runner(tg,
                                                  representative_data_gen,
                                                  core_config,
+                                                 fw_info,
                                                  fw_impl,
                                                  tb_w)

@@ -148,7 +151,8 @@ if FOUND_TORCH:
                                  tb_w,
                                  similarity_baseline_graph,
                                  quantized_graph,
-                                 fw_impl
+                                 fw_impl,
+                                 fw_info)

         exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
         if framework_platform_capabilities.tpc.add_metadata:
model_compression_toolkit/ptq/runner.py
CHANGED
@@ -16,6 +16,7 @@

 from typing import Callable

+from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -27,6 +28,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 def ptq_runner(tg: Graph,
                representative_data_gen: Callable,
                core_config: CoreConfig,
+               fw_info: FrameworkInfo,
                fw_impl: FrameworkImplementation,
                tb_w: TensorboardWriter) -> Graph:
     """
@@ -36,6 +38,7 @@ def ptq_runner(tg: Graph,
         tg: Graph to apply PTQ and to quantize.
         representative_data_gen (Callable): Dataset used for calibration.
         core_config: CoreConfig containing parameters of how the model should be quantized.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
@@ -47,5 +50,5 @@ def ptq_runner(tg: Graph,
     #############################################
     # Statistics Correction
     #############################################
-    tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)
+    tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
     return tg
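Since ptq_runner's full signature is visible in this hunk, the updated call shape is unambiguous. A minimal sketch mirroring the Keras facade above; tg, representative_data_gen, core_config and tb_w are placeholders:

from model_compression_toolkit.ptq.runner import ptq_runner
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation

fw_impl = KerasImplementation()
graph_with_stats_correction = ptq_runner(tg,
                                         representative_data_gen,
                                         core_config,
                                         DEFAULT_KERAS_INFO,  # fw_info is positional again in 2.4.2
                                         fw_impl,
                                         tb_w)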
model_compression_toolkit/qat/common/qat_config.py
CHANGED
@@ -19,17 +19,21 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod


-def is_qat_applicable(node: common.BaseNode
+def is_qat_applicable(node: common.BaseNode,
+                      fw_info: FrameworkInfo) -> bool:
     """
     A function for deciding if a layer should be fine-tuned during QAT

     Args:
         node (BaseNode): Node for quantization decision
+        fw_info (FrameworkInfo): Pytorch quantization information

     Returns:
         A boolean whether the layer is to be wrapped with a QuantizeWrapper
     """
-
+
+    kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
+    return (kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)) \
            or node.is_activation_quantization_enabled()

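The helper now needs the framework's FrameworkInfo to look up a node's kernel attribute before checking weights quantization. A sketch of the new call; node stands for any BaseNode:

from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO

if is_qat_applicable(node, DEFAULT_PYTORCH_INFO):
    # node has a quantized kernel attribute or quantized activations, so wrap it
    ...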
model_compression_toolkit/qat/keras/quantization_facade.py
CHANGED
@@ -37,9 +37,10 @@ if FOUND_TF:
     from tensorflow.keras.models import Model

     from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
+    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
+    from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
-    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info

     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder

@@ -51,6 +52,7 @@ if FOUND_TF:
     from model_compression_toolkit.constants import TENSORFLOW
     from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
+    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
         get_activation_quantizer_holder
     from model_compression_toolkit.qat.common.qat_config import QATConfig
@@ -71,11 +73,11 @@ if FOUND_TF:
         Returns: Wrapped layer

         """
-        if is_qat_applicable(n):
+        if is_qat_applicable(n, DEFAULT_KERAS_INFO):
             # If we are here, then the node has a kernel attribute to quantize and training during QAT
             weights_quantizers, _ = quantization_builder(n,
                                                          qat_config,
-                                                         n.
+                                                         DEFAULT_KERAS_INFO.get_kernel_op_attributes(n.type)[0])
             if len(weights_quantizers) > 0:
                 layer.trainable = True
                 return KerasTrainableQuantizationWrapper(layer, weights_quantizers)
@@ -85,7 +87,6 @@ if FOUND_TF:
         return layer


-    @set_keras_info
     def keras_quantization_aware_training_init_experimental(in_model: Model,
                                                             representative_data_gen: Callable,
                                                             target_resource_utilization: ResourceUtilization = None,
@@ -166,7 +167,7 @@ if FOUND_TF:

         >>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)

-        For more configuration options, please take a look at our `API documentation <https://
+        For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

         """

@@ -174,13 +175,16 @@ if FOUND_TF:
             f"If you encounter an issue, please open an issue in our GitHub "
             f"project https://github.com/sony/model_optimization")

+        KerasModelValidation(model=in_model,
+                             fw_info=DEFAULT_KERAS_INFO).validate()
+
         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
                                 "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
                                 "or pass a valid mixed precision configuration.")

-        tb_w = init_tensorboard_writer()
+        tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

         fw_impl = KerasImplementation()

@@ -194,15 +198,17 @@ if FOUND_TF:
         tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
                                                   representative_data_gen=representative_data_gen,
                                                   core_config=core_config,
+                                                  fw_info=DEFAULT_KERAS_INFO,
                                                   fw_impl=fw_impl,
                                                   fqc=target_platform_capabilities,
                                                   target_resource_utilization=target_resource_utilization,
                                                   tb_w=tb_w)

-        tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)
+        tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)

         _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
         qat_model, user_info = KerasModelBuilder(graph=tg,
+                                                 fw_info=DEFAULT_KERAS_INFO,
                                                  wrapper=_qat_wrapper,
                                                  get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
                                                                                             qat_config=qat_config)).build_model()
model_compression_toolkit/qat/pytorch/quantization_facade.py
CHANGED
@@ -36,7 +36,7 @@ if FOUND_TORCH:
     import torch.nn as nn
     from torch.nn import Module
     from mct_quantizers import PytorchActivationQuantizationHolder
-    from model_compression_toolkit.core.pytorch.default_framework_info import
+    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
     from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
@@ -62,10 +62,10 @@ if FOUND_TORCH:
         Returns: Wrapped layer

         """
-        if is_qat_applicable(n):
+        if is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
             # If we are here, then the node has a kernel attribute to quantize and training during QAT
             weights_quantizers, _ = quantization_builder(n, qat_config,
-                                                         n.
+                                                         DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(n.type)[0])
             if len(weights_quantizers) > 0:
                 return PytorchQuantizationWrapper(module, weights_quantizers)

@@ -74,7 +74,6 @@ if FOUND_TORCH:
         return module


-    @set_pytorch_info
     def pytorch_quantization_aware_training_init_experimental(in_model: Module,
                                                               representative_data_gen: Callable,
                                                               target_resource_utilization: ResourceUtilization = None,
@@ -136,7 +135,7 @@ if FOUND_TORCH:

         >>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, core_config=config)

-        For more configuration options, please take a look at our `API documentation <https://
+        For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

         """
         Logger.warning(
@@ -150,7 +149,7 @@ if FOUND_TORCH:
                                 "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
                                 "or pass a valid mixed precision configuration.")

-        tb_w = init_tensorboard_writer()
+        tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
         fw_impl = PytorchImplementation()

         target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
@@ -163,16 +162,18 @@ if FOUND_TORCH:
         tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
                                                   representative_data_gen=representative_data_gen,
                                                   core_config=core_config,
+                                                  fw_info=DEFAULT_PYTORCH_INFO,
                                                   fw_impl=fw_impl,
                                                   fqc=framework_platform_capabilities,
                                                   target_resource_utilization=target_resource_utilization,
                                                   tb_w=tb_w)

-        tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)
+        tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)

         _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

         qat_model, user_info = PyTorchModelBuilder(graph=tg,
+                                                   fw_info=DEFAULT_PYTORCH_INFO,
                                                    wrapper=_qat_wrapper,
                                                    get_activation_quantizer_holder_fn=partial(
                                                        get_activation_quantizer_holder,
@@ -180,6 +181,9 @@ if FOUND_TORCH:

         user_info.mixed_precision_cfg = bit_widths_config

+        # Remove fw_info from graph to enable saving the pytorch model (fw_info can not be pickled)
+        delattr(qat_model.graph, 'fw_info')
+
         return qat_model, user_info

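The new delattr call exists so the returned model can be pickled, per the inline comment above. Assuming that, saving the whole QAT-ready model should work again; a hedged sketch with model, repr_datagen and config as placeholders from the docstring example above:

import model_compression_toolkit as mct
import torch

qat_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(
    model, repr_datagen, core_config=config)

# fw_info was detached from qat_model.graph inside the facade, so pickling the
# whole model (which torch.save does by default) should no longer fail on it.
torch.save(qat_model, 'qat_model.pt')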
model_compression_toolkit/target_platform_capabilities/constants.py
CHANGED
@@ -29,7 +29,7 @@ QNNPACK_TP_MODEL = 'qnnpack'
 # TP Attributes
 KERNEL_ATTR = "kernel_attr"
 BIAS_ATTR = "bias_attr"
-
+POS_ATTR = "pos_attr"

 # TODO: this is duplicated from the core frameworks constants files, because the original consts can't be used here
 # duo to circular dependency. It might be best to extract the constants from the core file and put them here (in a
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
     AttachTpcToFramework
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
-from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices
+from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB


 class AttachTpcToPytorch(AttachTpcToFramework):
@@ -98,7 +98,7 @@ class AttachTpcToPytorch(AttachTpcToFramework):
             OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
                                                         Eq('p', 2) | Eq('p', None))],
             OperatorSetNames.SSD_POST_PROCESS: [],  # no such operator in pytorch
-            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
+            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB],
             OperatorSetNames.EXP: [torch.exp],
             OperatorSetNames.SIN: [torch.sin],
             OperatorSetNames.COS: [torch.cos],
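The added class ships in the same edgemdt_cl package as the other two, so this file now requires an edgemdt_cl version that provides it. A purely illustrative sketch of the expanded set:

from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB

# All three NMS layer types are now matched by the COMBINED_NON_MAX_SUPPRESSION
# operator set when attaching a TPC to a PyTorch model.
combined_nms_layers = (MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB)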
model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py
CHANGED
@@ -48,6 +48,7 @@ def get_trainable_quantizer_weights_config(
                                            final_attr_cfg.enable_weights_quantization,
                                            final_attr_cfg.weights_channels_axis[0], # Output channel axis
                                            final_attr_cfg.weights_per_channel_threshold,
+                                           final_node_cfg.min_threshold,
                                            weights_quantization_candidates)


@@ -75,6 +76,7 @@ def get_trainable_quantizer_activation_config(
                                               final_cfg.activation_n_bits,
                                               final_cfg.activation_quantization_params,
                                               final_cfg.enable_activation_quantization,
+                                              final_cfg.min_threshold,
                                               activation_quantization_candidates)

model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py
CHANGED
@@ -44,6 +44,7 @@ class TrainableQuantizerActivationConfig:
                  activation_n_bits: int,
                  activation_quantization_params: Dict,
                  enable_activation_quantization: bool,
+                 min_threshold: float,
                  activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
                  ):
         """
@@ -54,11 +55,13 @@ class TrainableQuantizerActivationConfig:
             activation_n_bits (int): Number of bits to quantize the activations.
             activation_quantization_params (Dict): Dictionary that contains activation quantization params.
             enable_activation_quantization (bool): Whether to quantize the layer's activations or not.
+            min_threshold (float): Minimum threshold to use during thresholds selection.
         """
         self.activation_quantization_method = activation_quantization_method
         self.activation_n_bits = activation_n_bits
         self.activation_quantization_params = activation_quantization_params
         self.enable_activation_quantization = enable_activation_quantization
+        self.min_threshold = min_threshold
         self.activation_bits_candidates = activation_quantization_candidates


@@ -70,6 +73,7 @@ class TrainableQuantizerWeightsConfig:
                  enable_weights_quantization: bool,
                  weights_channels_axis: int,
                  weights_per_channel_threshold: bool,
+                 min_threshold: float,
                  weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
                  ):
         """
@@ -82,6 +86,7 @@ class TrainableQuantizerWeightsConfig:
             enable_weights_quantization (bool): Whether to quantize the layer's weights or not.
             weights_channels_axis (int): Axis to quantize a node's kernel when quantizing per-channel.
             weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
+            min_threshold (float): Minimum threshold to use during thresholds selection.
         """
         self.weights_quantization_method = weights_quantization_method
         self.weights_n_bits = weights_n_bits
@@ -89,4 +94,5 @@ class TrainableQuantizerWeightsConfig:
         self.enable_weights_quantization = enable_weights_quantization
         self.weights_channels_axis = weights_channels_axis
         self.weights_per_channel_threshold = weights_per_channel_threshold
+        self.min_threshold = min_threshold
         self.weights_bits_candidates = weights_quantization_candidates
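Both trainable quantizer configs gain a required min_threshold argument just before the candidates list. A hedged construction sketch for the activation config; the values are placeholders, and importing QuantizationMethod from mct_quantizers matches its use in the deserialization code below:

from mct_quantizers import QuantizationMethod
from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
    TrainableQuantizerActivationConfig

cfg = TrainableQuantizerActivationConfig(
    activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
    activation_n_bits=8,
    activation_quantization_params={},   # placeholder params dict
    enable_activation_quantization=True,
    min_threshold=1e-16)                 # new required argument in 2.4.2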
model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py
CHANGED
@@ -77,11 +77,13 @@ def config_deserialization(in_config: dict) -> Union[TrainableQuantizerWeightsCo
                                               weights_quantization_params=weights_quantization_params,
                                               enable_weights_quantization=in_config[C.ENABLE_WEIGHTS_QUANTIZATION],
                                               weights_channels_axis=in_config[C.WEIGHTS_CHANNELS_AXIS],
-                                              weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD]
+                                              weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD],
+                                              min_threshold=in_config[C.MIN_THRESHOLD])
     elif in_config[C.IS_ACTIVATIONS]:
         return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod(in_config[C.ACTIVATION_QUANTIZATION_METHOD]),
                                                   activation_n_bits=in_config[C.ACTIVATION_N_BITS],
                                                   activation_quantization_params=in_config[C.ACTIVATION_QUANTIZATION_PARAMS],
-                                                  enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION]
+                                                  enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION],
+                                                  min_threshold=in_config[C.MIN_THRESHOLD])
     else:
         raise NotImplemented # pragma: no cover
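Deserialization now expects a C.MIN_THRESHOLD key, so configs serialized by older nightlies will not round-trip. Assuming the module's matching config_serialization helper (not shown in this diff) emits that key, a round trip looks like:

# Hedged sketch: config_serialization is assumed to be the counterpart of the
# config_deserialization shown above and to write the C.MIN_THRESHOLD key.
from model_compression_toolkit.trainable_infrastructure.keras.config_serialization import \
    config_serialization, config_deserialization

restored = config_deserialization(config_serialization(cfg))  # cfg from the sketch above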
model_compression_toolkit/xquant/__init__.py
CHANGED
@@ -16,4 +16,5 @@
 from model_compression_toolkit.xquant.common.xquant_config import XQuantConfig
 from model_compression_toolkit.xquant.keras.facade_xquant_report import xquant_report_keras_experimental
 from model_compression_toolkit.xquant.pytorch.facade_xquant_report import xquant_report_pytorch_experimental
+from model_compression_toolkit.xquant.pytorch.facade_xquant_report import xquant_report_troubleshoot_pytorch_experimental

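The package root now re-exports the new PyTorch troubleshoot report next to the existing entry points:

from model_compression_toolkit.xquant import (
    XQuantConfig,
    xquant_report_pytorch_experimental,
    xquant_report_troubleshoot_pytorch_experimental,  # new in 2.4.2
)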