mct-nightly 2.4.0.20250617.613-py3-none-any.whl → 2.4.0.20250619.621-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0

model_compression_toolkit/pruning/keras/pruning_facade.py

@@ -35,11 +35,12 @@ if FOUND_TF:
 AttachTpcToKeras
 from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
 from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
 from tensorflow.keras.models import Model

 DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

+@set_keras_info
 def keras_pruning_experimental(model: Model,
 target_resource_utilization: ResourceUtilization,
 representative_data_gen: Callable,
@@ -123,7 +124,6 @@ if FOUND_TF:
 float_graph = read_model_to_graph(model,
 representative_data_gen,
 target_platform_capabilities,
-DEFAULT_KERAS_INFO,
 fw_impl)

 # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
@@ -134,7 +134,6 @@ if FOUND_TF:

 # Create a Pruner object with the graph and configuration.
 pruner = Pruner(float_graph_with_compression_config,
-DEFAULT_KERAS_INFO,
 fw_impl,
 target_resource_utilization,
 representative_data_gen,
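
The pruning facade hunks above trade the explicit `DEFAULT_KERAS_INFO` argument for a `set_keras_info` decorator on the entry point. The real decorator is defined in `model_compression_toolkit/core/keras/default_framework_info.py` (itself rewritten in this release, +138 -87); the snippet below is only a minimal sketch of the registration pattern it appears to implement, with hypothetical names (`set_fw_info`, `KerasInfo`) used purely for illustration:

```python
# Hedged sketch only: a decorator that registers framework-specific info once,
# instead of threading an fw_info argument through every internal call.
# `set_fw_info` and `KerasInfo` are hypothetical names, not the library's API.
import functools

_active_fw_info = None  # module-level registry the rest of the core could consult


def set_fw_info(fw_info):
    global _active_fw_info
    _active_fw_info = fw_info


class KerasInfo:
    """Placeholder for the Keras-specific framework information object."""


def set_keras_info(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        set_fw_info(KerasInfo)  # make Keras framework info available before the facade runs
        return func(*args, **kwargs)
    return wrapper
```

Under such a pattern the facade no longer passes `fw_info` down the call chain, which is consistent with the argument removals that repeat throughout the rest of the diff.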

model_compression_toolkit/pruning/pytorch/pruning_facade.py

@@ -36,7 +36,7 @@ if FOUND_TORCH:
 from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
 from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
 PruningPytorchImplementation
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
 from torch.nn import Module
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
 AttachTpcToPytorch
@@ -44,6 +44,7 @@ if FOUND_TORCH:
 # Set the default Target Platform Capabilities (TPC) for PyTorch.
 DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

+@set_pytorch_info
 def pytorch_pruning_experimental(model: Module,
 target_resource_utilization: ResourceUtilization,
 representative_data_gen: Callable,
@@ -129,7 +130,6 @@ if FOUND_TORCH:
 float_graph = read_model_to_graph(model,
 representative_data_gen,
 framework_platform_capabilities,
-DEFAULT_PYTORCH_INFO,
 fw_impl)

 # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
@@ -140,7 +140,6 @@ if FOUND_TORCH:

 # Create a Pruner object with the graph and configuration.
 pruner = Pruner(float_graph_with_compression_config,
-DEFAULT_PYTORCH_INFO,
 fw_impl,
 target_resource_utilization,
 representative_data_gen,

model_compression_toolkit/ptq/keras/quantization_facade.py

@@ -36,7 +36,7 @@ from model_compression_toolkit.metadata import create_model_metadata
 if FOUND_TF:
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
 AttachTpcToKeras
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
 from tensorflow.keras.models import Model
@@ -49,6 +49,7 @@ if FOUND_TF:
 DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)


+@set_keras_info
 def keras_post_training_quantization(in_model: Model,
 representative_data_gen: Callable,
 target_resource_utilization: ResourceUtilization = None,
@@ -128,10 +129,7 @@ if FOUND_TF:
 if core_config.debug_config.bypass:
 return in_model, None

-
-
-KerasModelValidation(model=in_model,
-fw_info=fw_info).validate()
+KerasModelValidation(model=in_model).validate()

 if core_config.is_mixed_precision_enabled:
 if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
@@ -139,7 +137,7 @@ if FOUND_TF:
 "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
 "API, or pass a valid mixed precision configuration.") # pragma: no cover

-tb_w = init_tensorboard_writer(
+tb_w = init_tensorboard_writer()

 fw_impl = KerasImplementation()

@@ -153,7 +151,6 @@ if FOUND_TF:
 tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
 representative_data_gen=representative_data_gen,
 core_config=core_config,
-fw_info=fw_info,
 fw_impl=fw_impl,
 fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
@@ -169,7 +166,6 @@ if FOUND_TF:
 graph_with_stats_correction = ptq_runner(tg,
 representative_data_gen,
 core_config,
-fw_info,
 fw_impl,
 tb_w)

@@ -179,8 +175,7 @@ if FOUND_TF:
 tb_w,
 similarity_baseline_graph,
 quantized_graph,
-fw_impl,
-fw_info)
+fw_impl)

 exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
 if framework_platform_capabilities.tpc.add_metadata:
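
For library users the public entry point keeps its signature; only the internal `fw_info` plumbing changed, since the facade is now decorated with `@set_keras_info`. A minimal caller-side sketch of the Keras PTQ facade shown above (the model and representative dataset are placeholders):

```python
# Hedged usage sketch of the standard MCT Keras PTQ entry point.
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

model = tf.keras.applications.MobileNetV2()

def representative_data_gen():
    # Yield batches matching the model input; random data here for illustration only.
    for _ in range(10):
        yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    model, representative_data_gen)
```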

model_compression_toolkit/ptq/pytorch/quantization_facade.py

@@ -34,7 +34,7 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
 from model_compression_toolkit.metadata import create_model_metadata

 if FOUND_TORCH:
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
 from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 from torch.nn import Module
@@ -46,6 +46,7 @@ if FOUND_TORCH:

 DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

+@set_pytorch_info
 def pytorch_post_training_quantization(in_module: Module,
 representative_data_gen: Callable,
 target_resource_utilization: ResourceUtilization = None,
@@ -102,8 +103,6 @@ if FOUND_TORCH:
 if core_config.debug_config.bypass:
 return in_module, None

-fw_info = DEFAULT_PYTORCH_INFO
-
 if core_config.is_mixed_precision_enabled:
 if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
 Logger.critical("Given quantization config to mixed-precision facade is not of type "
@@ -111,7 +110,7 @@ if FOUND_TORCH:
 "pytorch_post_training_quantization API, or pass a valid mixed precision "
 "configuration.") # pragma: no cover

-tb_w = init_tensorboard_writer(
+tb_w = init_tensorboard_writer()

 fw_impl = PytorchImplementation()

@@ -125,7 +124,6 @@ if FOUND_TORCH:
 tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
 representative_data_gen=representative_data_gen,
 core_config=core_config,
-fw_info=fw_info,
 fw_impl=fw_impl,
 fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
@@ -141,7 +139,6 @@ if FOUND_TORCH:
 graph_with_stats_correction = ptq_runner(tg,
 representative_data_gen,
 core_config,
-fw_info,
 fw_impl,
 tb_w)

@@ -151,8 +148,7 @@ if FOUND_TORCH:
 tb_w,
 similarity_baseline_graph,
 quantized_graph,
-fw_impl,
-fw_info)
+fw_impl)

 exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
 if framework_platform_capabilities.tpc.add_metadata:

model_compression_toolkit/ptq/runner.py

@@ -16,7 +16,6 @@

 from typing import Callable

-from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -28,7 +27,6 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 def ptq_runner(tg: Graph,
 representative_data_gen: Callable,
 core_config: CoreConfig,
-fw_info: FrameworkInfo,
 fw_impl: FrameworkImplementation,
 tb_w: TensorboardWriter) -> Graph:
 """
@@ -38,7 +36,6 @@ def ptq_runner(tg: Graph,
 tg: Graph to apply PTQ and to quantize.
 representative_data_gen (Callable): Dataset used for calibration.
 core_config: CoreConfig containing parameters of how the model should be quantized.
-fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
 groups of layers by how they should be quantized, etc.)
 fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
@@ -50,5 +47,5 @@ def ptq_runner(tg: Graph,
 #############################################
 # Statistics Correction
 #############################################
-tg = apply_statistics_correction(tg, representative_data_gen, core_config,
+tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)
 return tg

model_compression_toolkit/qat/common/qat_config.py

@@ -19,21 +19,17 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod


-def is_qat_applicable(node: common.BaseNode,
-fw_info: FrameworkInfo) -> bool:
+def is_qat_applicable(node: common.BaseNode) -> bool:
 """
 A function for deciding if a layer should be fine-tuned during QAT

 Args:
 node (BaseNode): Node for quantization decision
-fw_info (FrameworkInfo): Pytorch quantization information

 Returns:
 A boolean whether the layer is to be wrapped with a QuantizeWrapper
 """
-
-kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
-return (kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)) \
+return (node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)) \
 or node.is_activation_quantization_enabled()


model_compression_toolkit/qat/keras/quantization_facade.py

@@ -37,10 +37,10 @@ if FOUND_TF:
 from tensorflow.keras.models import Model

 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
+from model_compression_toolkit.core.keras.default_framework_info import set_keras_info

 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder

@@ -52,7 +52,6 @@ if FOUND_TF:
 from model_compression_toolkit.constants import TENSORFLOW
 from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
 get_activation_quantizer_holder
 from model_compression_toolkit.qat.common.qat_config import QATConfig
@@ -73,11 +72,11 @@ if FOUND_TF:
 Returns: Wrapped layer

 """
-if is_qat_applicable(n
+if is_qat_applicable(n):
 # If we are here, then the node has a kernel attribute to quantize and training during QAT
 weights_quantizers, _ = quantization_builder(n,
 qat_config,
-
+n.kernel_attr)
 if len(weights_quantizers) > 0:
 layer.trainable = True
 return KerasTrainableQuantizationWrapper(layer, weights_quantizers)
@@ -87,6 +86,7 @@ if FOUND_TF:
 return layer


+@set_keras_info
 def keras_quantization_aware_training_init_experimental(in_model: Model,
 representative_data_gen: Callable,
 target_resource_utilization: ResourceUtilization = None,
@@ -175,8 +175,7 @@ if FOUND_TF:
 f"If you encounter an issue, please open an issue in our GitHub "
 f"project https://github.com/sony/model_optimization")

-KerasModelValidation(model=in_model,
-fw_info=DEFAULT_KERAS_INFO).validate()
+KerasModelValidation(model=in_model).validate()

 if core_config.is_mixed_precision_enabled:
 if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
@@ -184,7 +183,7 @@ if FOUND_TF:
 "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
 "or pass a valid mixed precision configuration.")

-tb_w = init_tensorboard_writer(
+tb_w = init_tensorboard_writer()

 fw_impl = KerasImplementation()

@@ -198,17 +197,15 @@ if FOUND_TF:
 tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
 representative_data_gen=representative_data_gen,
 core_config=core_config,
-fw_info=DEFAULT_KERAS_INFO,
 fw_impl=fw_impl,
 fqc=target_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)

-tg = ptq_runner(tg, representative_data_gen, core_config,
+tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)

 _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
 qat_model, user_info = KerasModelBuilder(graph=tg,
-fw_info=DEFAULT_KERAS_INFO,
 wrapper=_qat_wrapper,
 get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
 qat_config=qat_config)).build_model()

model_compression_toolkit/qat/pytorch/quantization_facade.py

@@ -36,7 +36,7 @@ if FOUND_TORCH:
 import torch.nn as nn
 from torch.nn import Module
 from mct_quantizers import PytorchActivationQuantizationHolder
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
 from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
@@ -62,10 +62,10 @@ if FOUND_TORCH:
 Returns: Wrapped layer

 """
-if is_qat_applicable(n
+if is_qat_applicable(n):
 # If we are here, then the node has a kernel attribute to quantize and training during QAT
 weights_quantizers, _ = quantization_builder(n, qat_config,
-
+n.kernel_attr)
 if len(weights_quantizers) > 0:
 return PytorchQuantizationWrapper(module, weights_quantizers)

@@ -74,6 +74,7 @@ if FOUND_TORCH:
 return module


+@set_pytorch_info
 def pytorch_quantization_aware_training_init_experimental(in_model: Module,
 representative_data_gen: Callable,
 target_resource_utilization: ResourceUtilization = None,
@@ -149,7 +150,7 @@ if FOUND_TORCH:
 "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
 "or pass a valid mixed precision configuration.")

-tb_w = init_tensorboard_writer(
+tb_w = init_tensorboard_writer()
 fw_impl = PytorchImplementation()

 target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
@@ -162,18 +163,16 @@ if FOUND_TORCH:
 tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
 representative_data_gen=representative_data_gen,
 core_config=core_config,
-fw_info=DEFAULT_PYTORCH_INFO,
 fw_impl=fw_impl,
 fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)

-tg = ptq_runner(tg, representative_data_gen, core_config,
+tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)

 _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

 qat_model, user_info = PyTorchModelBuilder(graph=tg,
-fw_info=DEFAULT_PYTORCH_INFO,
 wrapper=_qat_wrapper,
 get_activation_quantizer_holder_fn=partial(
 get_activation_quantizer_holder,
@@ -181,9 +180,6 @@ if FOUND_TORCH:

 user_info.mixed_precision_cfg = bit_widths_config

-# Remove fw_info from graph to enable saving the pytorch model (fw_info can not be pickled)
-delattr(qat_model.graph, 'fw_info')
-
 return qat_model, user_info

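
The PyTorch QAT facade follows the same pattern: the `fw_info` argument and the `delattr` workaround disappear, while the caller-facing signature stays put. A hedged usage sketch, assuming a torchvision model purely for illustration:

```python
# Hedged usage sketch of the PyTorch QAT entry points touched above.
import numpy as np
from torchvision.models import mobilenet_v2
import model_compression_toolkit as mct

model = mobilenet_v2(weights=None)

def representative_data_gen():
    # Random NCHW batches stand in for real calibration data.
    for _ in range(10):
        yield [np.random.randn(1, 3, 224, 224).astype(np.float32)]

qat_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(
    model, representative_data_gen)

# ... fine-tune qat_model with a regular PyTorch training loop, then finalize
# with the companion API (not part of this diff):
quantized_model = mct.qat.pytorch_quantization_aware_training_finalize_experimental(qat_model)
```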

model_compression_toolkit/xquant/common/core_report_generator.py

@@ -50,7 +50,7 @@ def core_report_generator(float_model: Any,

 # Collect histograms on the float model.
 float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
-mi = ModelCollector(float_graph, fw_report_utils.fw_impl
+mi = ModelCollector(float_graph, fw_report_utils.fw_impl)
 for _data in tqdm(repr_dataset(), desc="Collecting Histograms"):
 mi.infer(_data)


model_compression_toolkit/xquant/common/framework_report_utils.py

@@ -34,7 +34,6 @@ class FrameworkReportUtils:
 """

 def __init__(self,
-fw_info: FrameworkInfo,
 fw_impl: FrameworkImplementation,
 similarity_calculator: SimilarityCalculator,
 dataset_utils: DatasetUtils,
@@ -45,7 +44,6 @@ class FrameworkReportUtils:
 Initializes the FrameworkReportUtils class with various utility components required for generating the report.

 Args:
-fw_info (FrameworkInfo): Information about the framework being used.
 fw_impl (FrameworkImplementation): The implemented functions of the framework.
 similarity_calculator (SimilarityCalculator): A utility for calculating similarity metrics.
 dataset_utils (DatasetUtils): Utilities for handling datasets.
@@ -53,7 +51,6 @@ class FrameworkReportUtils:
 tb_utils (TensorboardUtils): Utilities for TensorBoard operations.
 get_metadata_fn (Callable): Function to retrieve the metadata from the quantized model.
 """
-self.fw_info = fw_info
 self.fw_impl = fw_impl
 self.similarity_calculator = similarity_calculator
 self.dataset_utils = dataset_utils

model_compression_toolkit/xquant/common/model_folding_utils.py

@@ -34,7 +34,6 @@ class ModelFoldingUtils:
 """

 def __init__(self,
-fw_info: FrameworkInfo,
 fw_impl: FrameworkImplementation,
 fw_default_fqc: FrameworkQuantizationCapabilities):
 """
@@ -42,11 +41,9 @@ class ModelFoldingUtils:
 and default FQC.

 Args:
-fw_info: Framework-specific information.
 fw_impl: Implementation functions for the framework.
 fw_default_fqc: Default target platform capabilities for the handled framework.
 """
-self.fw_info = fw_info
 self.fw_impl = fw_impl
 self.fw_default_fqc = fw_default_fqc

@@ -69,8 +66,7 @@ class ModelFoldingUtils:
 float_folded_model, _ = self.fw_impl.model_builder(
 float_graph,
 mode=ModelBuilderMode.FLOAT,
-append2output=None,
-fw_info=self.fw_info
+append2output=None
 )
 return float_folded_model

@@ -100,7 +96,6 @@ class ModelFoldingUtils:
 graph = graph_preparation_runner(in_model=model,
 representative_data_gen=repr_dataset,
 fw_impl=self.fw_impl,
-fw_info=self.fw_info,
 quantization_config=DEFAULTCONFIG,
 fqc=self.fw_default_fqc)
 return graph

model_compression_toolkit/xquant/common/tensorboard_utils.py

@@ -36,19 +36,16 @@ class TensorboardUtils:

 def __init__(self,
 report_dir: str,
-fw_info: FrameworkInfo,
 fw_impl: FrameworkImplementation):
 """
 Initialize the TensorboardUtils.

 Args:
 report_dir (str): Directory where Tensorboard logs will be stored.
-fw_info (FrameworkInfo): Framework-specific information.
 fw_impl (FrameworkImplementation): Framework-specific implementation.
 """
 self.fw_impl = fw_impl
-self.fw_info = fw_info
-self.tb_writer = TensorboardWriter(report_dir, fw_info)
+self.tb_writer = TensorboardWriter(report_dir)
 Logger.info(f"Please run: tensorboard --logdir {self.tb_writer.dir_path}")

 def get_graph_for_tensorboard_display(self,

model_compression_toolkit/xquant/keras/keras_report_utils.py

@@ -15,7 +15,6 @@

 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
 AttachTpcToKeras
@@ -40,7 +39,6 @@ class KerasReportUtils(FrameworkReportUtils):
 Args:
 report_dir: Logging dir path.
 """
-fw_info = DEFAULT_KERAS_INFO
 fw_impl = KerasImplementation()

 # Set the default Target Platform Capabilities (TPC) for Keras.
@@ -49,8 +47,7 @@ class KerasReportUtils(FrameworkReportUtils):
 framework_platform_capabilities = attach2pytorch.attach(default_tpc)

 dataset_utils = KerasDatasetUtils()
-model_folding = ModelFoldingUtils(
-fw_impl=fw_impl,
+model_folding = ModelFoldingUtils(fw_impl=fw_impl,
 fw_default_fqc=framework_platform_capabilities)

 similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
@@ -59,10 +56,8 @@ class KerasReportUtils(FrameworkReportUtils):
 model_analyzer_utils=KerasModelAnalyzer())

 tb_utils = KerasTensorboardUtils(report_dir=report_dir,
-fw_impl=fw_impl
-
-super().__init__(fw_info,
-fw_impl,
+fw_impl=fw_impl)
+super().__init__(fw_impl,
 similarity_calculator,
 dataset_utils,
 model_folding,

model_compression_toolkit/xquant/keras/tensorboard_utils.py

@@ -40,18 +40,15 @@ class KerasTensorboardUtils(TensorboardUtils):
 """

 def __init__(self, report_dir: str,
-fw_info: FrameworkInfo,
 fw_impl: FrameworkImplementation):
 """
 Initialize the KerasTensorboardUtils class with the given parameters.

 Args:
 report_dir (str): Directory where the TensorBoard files will be stored.
-fw_info (FrameworkInfo): Information about the framework being used.
 fw_impl (FrameworkImplementation): Implementation functions for the framework.
 """
 super().__init__(report_dir,
-fw_info,
 fw_impl)

 def get_graph_for_tensorboard_display(self,

model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py

@@ -20,7 +20,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
 AttachTpcToPytorch

 from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo
 from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
 from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
 from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
@@ -30,6 +30,7 @@ from model_compression_toolkit.xquant.pytorch.similarity_functions import Pytorc
 from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
 from mct_quantizers.pytorch.metadata import get_metadata

+
 class PytorchReportUtils(FrameworkReportUtils):
 """
 Class with various utility components required for generating the report for a Pytorch model.
@@ -39,7 +40,6 @@ class PytorchReportUtils(FrameworkReportUtils):
 Args:
 report_dir: Logging dir path.
 """
-fw_info = DEFAULT_PYTORCH_INFO
 fw_impl = PytorchImplementation()
 # Set the default Target Platform Capabilities (TPC) for PyTorch.
 default_tpc = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -47,8 +47,7 @@ class PytorchReportUtils(FrameworkReportUtils):
 framework_quantization_capabilities = attach2pytorch.attach(default_tpc)

 dataset_utils = PytorchDatasetUtils()
-model_folding = ModelFoldingUtils(
-fw_impl=fw_impl,
+model_folding = ModelFoldingUtils(fw_impl=fw_impl,
 fw_default_fqc=framework_quantization_capabilities)

 similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
@@ -58,11 +57,9 @@ class PytorchReportUtils(FrameworkReportUtils):
 device=get_working_device())

 tb_utils = PytorchTensorboardUtils(report_dir=report_dir,
-fw_impl=fw_impl
-fw_info=fw_info)
+fw_impl=fw_impl)

-super().__init__(
-fw_impl=fw_impl,
+super().__init__(fw_impl=fw_impl,
 tb_utils=tb_utils,
 dataset_utils=dataset_utils,
 similarity_calculator=similarity_calculator,

model_compression_toolkit/xquant/pytorch/tensorboard_utils.py

@@ -41,18 +41,15 @@ class PytorchTensorboardUtils(TensorboardUtils):

 def __init__(self,
 report_dir: str,
-fw_info: FrameworkInfo,
 fw_impl: FrameworkImplementation):
 """
 Initialize the PytorchTensorboardUtils instance.

 Args:
 report_dir: Directory where the reports are stored.
-fw_info: Information about the framework being used.
 fw_impl: Implementation methods for the framework.
 """
 super().__init__(report_dir,
-fw_info,
 fw_impl)

 def get_graph_for_tensorboard_display(self,

{mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL RENAMED
File without changes

{mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md RENAMED
File without changes

{mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt RENAMED
File without changes