mct-nightly 2.4.0.20250706.701__py3-none-any.whl → 2.4.0.20250707.643__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/RECORD +36 -38
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +4 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +7 -4
- model_compression_toolkit/core/common/model_collector.py +11 -0
- model_compression_toolkit/core/common/pruning/memory_calculator.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +22 -87
- model_compression_toolkit/core/common/quantization/quantization_config.py +0 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +23 -17
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +26 -48
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +12 -7
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -14
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -1
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -13
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +5 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +7 -5
- model_compression_toolkit/core/graph_prep_runner.py +1 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +21 -11
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +8 -0
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +9 -1
- model_compression_toolkit/core/quantization_prep_runner.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/qat/keras/quantization_facade.py +0 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +0 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +0 -6
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +2 -4
- model_compression_toolkit/core/common/model_validation.py +0 -41
- model_compression_toolkit/core/keras/keras_model_validation.py +0 -37
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/top_level.txt +0 -0
|
@@ -29,6 +29,10 @@ import numpy as np
|
|
|
29
29
|
from model_compression_toolkit.logger import Logger
|
|
30
30
|
|
|
31
31
|
|
|
32
|
+
# default output channel axis to use when it's not defined in node's fw_info.
|
|
33
|
+
_default_output_channel_axis = 1
|
|
34
|
+
|
|
35
|
+
|
|
32
36
|
class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
|
|
33
37
|
"""
|
|
34
38
|
Implementation of the PruningFramework for the Pytorch framework. This class provides
|
|
@@ -190,6 +194,10 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
|
190
194
|
|
|
191
195
|
return attributes_with_axis
|
|
192
196
|
|
|
197
|
+
@property
|
|
198
|
+
def default_output_channel_axis(self):
|
|
199
|
+
return _default_output_channel_axis
|
|
200
|
+
|
|
193
201
|
|
|
194
202
|
def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
|
|
195
203
|
"""
|
|
@@ -283,7 +291,7 @@ def _edit_node_input_shape(node: BaseNode,
|
|
|
283
291
|
|
|
284
292
|
# Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
|
|
285
293
|
# This is done by summing the mask, as each '1' in the mask represents a retained channel.
|
|
286
|
-
channel_axis = node.out_channel_axis
|
|
294
|
+
channel_axis = _default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
|
|
287
295
|
new_input_shape[0][channel_axis] = int(np.sum(input_mask))
|
|
288
296
|
|
|
289
297
|
# Update the node's input shape with the new dimensions.
|
|
@@ -87,8 +87,8 @@ def quantization_preparation_runner(graph: Graph,
|
|
|
87
87
|
# Calculate quantization params
|
|
88
88
|
######################################
|
|
89
89
|
|
|
90
|
-
calculate_quantization_params(graph, fw_impl=fw_impl,
|
|
91
|
-
hessian_info_service=hessian_info_service)
|
|
90
|
+
calculate_quantization_params(graph, core_config.quantization_config, fw_impl=fw_impl,
|
|
91
|
+
repr_data_gen_fn=representative_data_gen, hessian_info_service=hessian_info_service)
|
|
92
92
|
|
|
93
93
|
if tb_w is not None:
|
|
94
94
|
tb_w.add_graph(graph, 'thresholds_selection')
|
|
@@ -43,7 +43,6 @@ if FOUND_TF:
|
|
|
43
43
|
import tensorflow as tf
|
|
44
44
|
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
|
|
45
45
|
from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
|
|
46
|
-
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
|
|
47
46
|
from tensorflow.keras.models import Model
|
|
48
47
|
from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
|
|
49
48
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
@@ -235,8 +234,6 @@ if FOUND_TF:
|
|
|
235
234
|
if core_config.debug_config.bypass:
|
|
236
235
|
return in_model, None
|
|
237
236
|
|
|
238
|
-
KerasModelValidation(model=in_model).validate()
|
|
239
|
-
|
|
240
237
|
if core_config.is_mixed_precision_enabled:
|
|
241
238
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
242
239
|
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
|
|
@@ -38,7 +38,6 @@ if FOUND_TF:
|
|
|
38
38
|
AttachTpcToKeras
|
|
39
39
|
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
|
|
40
40
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
|
41
|
-
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
|
|
42
41
|
from tensorflow.keras.models import Model
|
|
43
42
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
44
43
|
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
|
|
@@ -129,8 +128,6 @@ if FOUND_TF:
|
|
|
129
128
|
if core_config.debug_config.bypass:
|
|
130
129
|
return in_model, None
|
|
131
130
|
|
|
132
|
-
KerasModelValidation(model=in_model).validate()
|
|
133
|
-
|
|
134
131
|
if core_config.is_mixed_precision_enabled:
|
|
135
132
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
136
133
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
|
@@ -38,7 +38,6 @@ if FOUND_TF:
|
|
|
38
38
|
|
|
39
39
|
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
|
|
40
40
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
|
41
|
-
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
|
|
42
41
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
43
42
|
from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
|
|
44
43
|
|
|
@@ -175,8 +174,6 @@ if FOUND_TF:
|
|
|
175
174
|
f"If you encounter an issue, please open an issue in our GitHub "
|
|
176
175
|
f"project https://github.com/sony/model_optimization")
|
|
177
176
|
|
|
178
|
-
KerasModelValidation(model=in_model).validate()
|
|
179
|
-
|
|
180
177
|
if core_config.is_mixed_precision_enabled:
|
|
181
178
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
182
179
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
|
@@ -48,7 +48,6 @@ def get_trainable_quantizer_weights_config(
|
|
|
48
48
|
final_attr_cfg.enable_weights_quantization,
|
|
49
49
|
final_attr_cfg.weights_channels_axis[0], # Output channel axis
|
|
50
50
|
final_attr_cfg.weights_per_channel_threshold,
|
|
51
|
-
final_node_cfg.min_threshold,
|
|
52
51
|
weights_quantization_candidates)
|
|
53
52
|
|
|
54
53
|
|
|
@@ -76,7 +75,6 @@ def get_trainable_quantizer_activation_config(
|
|
|
76
75
|
final_cfg.activation_n_bits,
|
|
77
76
|
final_cfg.activation_quantization_params,
|
|
78
77
|
final_cfg.enable_activation_quantization,
|
|
79
|
-
final_cfg.min_threshold,
|
|
80
78
|
activation_quantization_candidates)
|
|
81
79
|
|
|
82
80
|
|
|
@@ -44,7 +44,6 @@ class TrainableQuantizerActivationConfig:
|
|
|
44
44
|
activation_n_bits: int,
|
|
45
45
|
activation_quantization_params: Dict,
|
|
46
46
|
enable_activation_quantization: bool,
|
|
47
|
-
min_threshold: float,
|
|
48
47
|
activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
|
|
49
48
|
):
|
|
50
49
|
"""
|
|
@@ -55,13 +54,11 @@ class TrainableQuantizerActivationConfig:
|
|
|
55
54
|
activation_n_bits (int): Number of bits to quantize the activations.
|
|
56
55
|
activation_quantization_params (Dict): Dictionary that contains activation quantization params.
|
|
57
56
|
enable_activation_quantization (bool): Whether to quantize the layer's activations or not.
|
|
58
|
-
min_threshold (float): Minimum threshold to use during thresholds selection.
|
|
59
57
|
"""
|
|
60
58
|
self.activation_quantization_method = activation_quantization_method
|
|
61
59
|
self.activation_n_bits = activation_n_bits
|
|
62
60
|
self.activation_quantization_params = activation_quantization_params
|
|
63
61
|
self.enable_activation_quantization = enable_activation_quantization
|
|
64
|
-
self.min_threshold = min_threshold
|
|
65
62
|
self.activation_bits_candidates = activation_quantization_candidates
|
|
66
63
|
|
|
67
64
|
|
|
@@ -73,7 +70,6 @@ class TrainableQuantizerWeightsConfig:
|
|
|
73
70
|
enable_weights_quantization: bool,
|
|
74
71
|
weights_channels_axis: int,
|
|
75
72
|
weights_per_channel_threshold: bool,
|
|
76
|
-
min_threshold: float,
|
|
77
73
|
weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
|
|
78
74
|
):
|
|
79
75
|
"""
|
|
@@ -86,7 +82,6 @@ class TrainableQuantizerWeightsConfig:
|
|
|
86
82
|
enable_weights_quantization (bool): Whether to quantize the layer's weights or not.
|
|
87
83
|
weights_channels_axis (int): Axis to quantize a node's kernel when quantizing per-channel.
|
|
88
84
|
weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
|
|
89
|
-
min_threshold (float): Minimum threshold to use during thresholds selection.
|
|
90
85
|
"""
|
|
91
86
|
self.weights_quantization_method = weights_quantization_method
|
|
92
87
|
self.weights_n_bits = weights_n_bits
|
|
@@ -94,5 +89,4 @@ class TrainableQuantizerWeightsConfig:
|
|
|
94
89
|
self.enable_weights_quantization = enable_weights_quantization
|
|
95
90
|
self.weights_channels_axis = weights_channels_axis
|
|
96
91
|
self.weights_per_channel_threshold = weights_per_channel_threshold
|
|
97
|
-
self.min_threshold = min_threshold
|
|
98
92
|
self.weights_bits_candidates = weights_quantization_candidates
|
|
@@ -77,13 +77,11 @@ def config_deserialization(in_config: dict) -> Union[TrainableQuantizerWeightsCo
|
|
|
77
77
|
weights_quantization_params=weights_quantization_params,
|
|
78
78
|
enable_weights_quantization=in_config[C.ENABLE_WEIGHTS_QUANTIZATION],
|
|
79
79
|
weights_channels_axis=in_config[C.WEIGHTS_CHANNELS_AXIS],
|
|
80
|
-
weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD]
|
|
81
|
-
min_threshold=in_config[C.MIN_THRESHOLD])
|
|
80
|
+
weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD])
|
|
82
81
|
elif in_config[C.IS_ACTIVATIONS]:
|
|
83
82
|
return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod(in_config[C.ACTIVATION_QUANTIZATION_METHOD]),
|
|
84
83
|
activation_n_bits=in_config[C.ACTIVATION_N_BITS],
|
|
85
84
|
activation_quantization_params=in_config[C.ACTIVATION_QUANTIZATION_PARAMS],
|
|
86
|
-
enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION]
|
|
87
|
-
min_threshold=in_config[C.MIN_THRESHOLD])
|
|
85
|
+
enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION])
|
|
88
86
|
else:
|
|
89
87
|
raise NotImplemented # pragma: no cover
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
|
-
from typing import Any
|
|
3
|
-
|
|
4
|
-
from model_compression_toolkit.core import FrameworkInfo
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class ModelValidation:
|
|
8
|
-
"""
|
|
9
|
-
Class to define validation methods in order to validate the received model to quantize.
|
|
10
|
-
"""
|
|
11
|
-
|
|
12
|
-
def __init__(self,
|
|
13
|
-
model: Any):
|
|
14
|
-
"""
|
|
15
|
-
Initialize a ModelValidation object.
|
|
16
|
-
|
|
17
|
-
Args:
|
|
18
|
-
model: Model to check its validity.
|
|
19
|
-
"""
|
|
20
|
-
self.model = model
|
|
21
|
-
|
|
22
|
-
@abstractmethod
|
|
23
|
-
def validate_output_channel_consistency(self):
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
Validate that output channels index in all layers of the model are the same.
|
|
27
|
-
If the model has layers with different output channels index, it should throw an exception.
|
|
28
|
-
|
|
29
|
-
"""
|
|
30
|
-
raise NotImplemented(
|
|
31
|
-
f'Framework validation class did not implement validate_output_channel_consistency') # pragma: no cover
|
|
32
|
-
|
|
33
|
-
def validate(self):
|
|
34
|
-
"""
|
|
35
|
-
|
|
36
|
-
Run all validation methods before the quantization process starts.
|
|
37
|
-
|
|
38
|
-
"""
|
|
39
|
-
self.validate_output_channel_consistency()
|
|
40
|
-
|
|
41
|
-
|
|
@@ -1,37 +0,0 @@
|
|
|
1
|
-
from tensorflow.keras.models import Model
|
|
2
|
-
|
|
3
|
-
from model_compression_toolkit.core.common.framework_info import get_fw_info
|
|
4
|
-
from model_compression_toolkit.core.common.framework_info import ChannelAxis
|
|
5
|
-
from model_compression_toolkit.core.common.model_validation import ModelValidation
|
|
6
|
-
from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class KerasModelValidation(ModelValidation):
|
|
10
|
-
"""
|
|
11
|
-
Class to define validation methods in order to validate the received Keras model to quantize.
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
def __init__(self, model: Model):
|
|
15
|
-
"""
|
|
16
|
-
Initialize a KerasModelValidation object.
|
|
17
|
-
|
|
18
|
-
Args:
|
|
19
|
-
model: Keras model to check its validity.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
super(KerasModelValidation, self).__init__(model=model)
|
|
23
|
-
|
|
24
|
-
def validate_output_channel_consistency(self):
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
Validate that output channels index in all layers of the model are the same.
|
|
28
|
-
If the model has layers with different output channels index, an exception is thrown.
|
|
29
|
-
|
|
30
|
-
"""
|
|
31
|
-
fw_info = get_fw_info()
|
|
32
|
-
for layer in self.model.layers:
|
|
33
|
-
data_format = layer.get_config().get(CHANNELS_FORMAT)
|
|
34
|
-
if data_format is not None:
|
|
35
|
-
assert (data_format == CHANNELS_FORMAT_LAST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NHWC.value
|
|
36
|
-
or data_format == CHANNELS_FORMAT_FIRST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NCHW.value), \
|
|
37
|
-
f'Model can not have layers with different data formats.'
|
|
File without changes
|
|
File without changes
|
{mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/top_level.txt
RENAMED
|
File without changes
|