mct-nightly 1.8.0.27022023.post430-py3-none-any.whl → 1.8.0.27032023.post403-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/METADATA +7 -7
- {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/RECORD +65 -59
- {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +9 -15
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/keras/quantization_facade.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +4 -10
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -39
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +39 -24
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +50 -42
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -36
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +24 -5
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +60 -106
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_training.py +28 -38
- model_compression_toolkit/gptq/keras/gptq_training.py +10 -28
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +6 -12
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +0 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +22 -128
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +11 -41
- model_compression_toolkit/gptq/pytorch/gptq_training.py +12 -4
- model_compression_toolkit/gptq/pytorch/graph_info.py +9 -6
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -22
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +0 -20
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -1
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py +14 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +236 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +9 -31
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +30 -37
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +27 -36
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +21 -21
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +25 -26
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +12 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +12 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +53 -2
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +22 -4
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +24 -3
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +0 -0
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
@@ -14,46 +14,52 @@
 # ==============================================================================
 from typing import Tuple
 
-import tensorflow as tf
-from tensorflow.keras.layers import Layer
 
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common import Graph, Logger
+from model_compression_toolkit.core.common.constants import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
-…  (old lines 24-59 not shown)
+
+if FOUND_TF:
+    import tensorflow as tf
+    from tensorflow.keras.layers import Layer
+    from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
+
+    def _get_wrapper(node: common.BaseNode,
+                     layer: Layer) -> qi.KerasQuantizationWrapper:
+        """
+        A function which takes a computational graph node and a keras layer and perform the quantization wrapping
+        Args:
+            n: A node of mct graph.
+            layer: A keras layer
+
+        Returns: Wrapped layer with weights quantizers and activation quantizers
+
+        """
+        weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
+
+
+    def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
+        """
+        Convert graph to an exportable Keras model (model with all quantization parameters).
+        An exportable model can then be exported using model_exporter, to retrieve the
+        final exported model.
+
+        Args:
+            graph: Graph to convert to an exportable Keras model.
+
+        Returns:
+            Exportable Keras model and user information.
+        """
+        exportable_model, user_info = KerasModelBuilder(graph=graph,
+                                                        wrapper=_get_wrapper).build_model()
+        exportable_model.trainable = False
+        return exportable_model, user_info
+else:
+    def get_exportable_keras_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using get_exportable_keras_model. '
+                     'Could not find Tensorflow package.')
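Note: this hunk introduces the optional-dependency gate used throughout the release: the framework is imported only after a feature flag (FOUND_TF) confirms it is installed, and the public name is otherwise bound to a stub that reports an error. A minimal self-contained sketch of that pattern, with generic stand-in names rather than MCT's actual code:

    import importlib.util

    # Analogous to FOUND_TF: True only if TensorFlow can be imported.
    FOUND_TF = importlib.util.find_spec('tensorflow') is not None

    if FOUND_TF:
        import tensorflow as tf

        def get_exportable_model(model):
            # A real implementation would wrap layers with quantizers here.
            model.trainable = False
            return model
    else:
        def get_exportable_model(*args, **kwargs):
            # Same public name, but fails loudly when TensorFlow is missing.
            raise ImportError('TensorFlow is required for get_exportable_model.')

MCT reports through its own Logger rather than raising ImportError directly; the sketch substitutes a plain exception to stay dependency-free.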
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py
CHANGED
@@ -15,15 +15,12 @@
 from typing import Dict, Any
 
 from model_compression_toolkit.core.common import BaseNode, Logger
-from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED
+from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
 from model_compression_toolkit.core.common.target_platform import QuantizationMethod
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import …  (old line 21 truncated)
-…  (old line 22 not shown)
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers \
-    import \
-    BaseKerasInferableQuantizer
-
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts
 
 def get_inferable_quantizer_kwargs(node: BaseNode,
                                    quantization_target: QuantizationTarget) -> Dict[str, Any]:
@@ -44,19 +41,29 @@ def get_inferable_quantizer_kwargs(node: BaseNode,
         # Return the appropriate quantization parameters based on the quantization method
         if quantization_method in [QuantizationMethod.POWER_OF_TWO,
                                    QuantizationMethod.SYMMETRIC]:
-            return {
-                …  (old lines 48-51 not shown)
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
 
         elif quantization_method in [QuantizationMethod.UNIFORM]:
-            return {
-                …  (old lines 55-59 not shown)
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.MIN_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
+                    qi_keras_consts.MAX_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
+
+        elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
+
         else:
             Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
 
@@ -68,16 +75,24 @@ def get_inferable_quantizer_kwargs(node: BaseNode,
         # Return the appropriate quantization parameters based on the quantization method
         if quantization_method in [QuantizationMethod.POWER_OF_TWO,
                                    QuantizationMethod.SYMMETRIC]:
-            return {
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
                     # In activation quantization is per-tensor only - thus we hold the threshold as a list with a len of 1
-                …  (old lines 73-74 not shown)
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]],
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED]}
 
         elif quantization_method in [QuantizationMethod.UNIFORM]:
-            return {
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
                     # In activation quantization is per-tensor only - thus we hold the min/max as a list with a len of 1
-                …  (old lines 79-80 not shown)
+                    qi_keras_consts.MIN_RANGE: [node_qc.activation_quantization_params[RANGE_MIN]],
+                    qi_keras_consts.MAX_RANGE: [node_qc.activation_quantization_params[RANGE_MAX]]}
+
+        elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED],
+                    qi_keras_consts.CLUSTER_CENTERS: node_qc.activation_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]]
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    }
         else:
             Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
     else:
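Note: these kwargs builders pair with get_inferable_quantizer_class: one lookup maps a quantization method to a quantizer class, the other builds that class's constructor arguments, and the caller instantiates cls(**kwargs). A sketch of that lookup-then-instantiate pattern with stand-in names (not MCT's actual registry):

    from enum import Enum
    from typing import Any, Dict, Type

    class Method(Enum):
        SYMMETRIC = 0
        UNIFORM = 1

    class SymmetricQuantizer:
        # Stand-in for an inferable quantizer class.
        def __init__(self, num_bits: int, threshold: list):
            self.num_bits, self.threshold = num_bits, threshold

    _REGISTRY: Dict[Method, Type] = {Method.SYMMETRIC: SymmetricQuantizer}

    def get_quantizer_class(method: Method) -> Type:
        # Analogous to get_inferable_quantizer_class: method -> class.
        return _REGISTRY[method]

    def get_quantizer_kwargs(method: Method) -> Dict[str, Any]:
        # Analogous to get_inferable_quantizer_kwargs: method -> constructor kwargs.
        return {'num_bits': 8, 'threshold': [1.0]}

    method = Method.SYMMETRIC
    quantizer = get_quantizer_class(method)(**get_quantizer_kwargs(method))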
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py
CHANGED
@@ -14,61 +14,69 @@
 # ==============================================================================
 from typing import Any
 
-from keras.engine.input_layer import InputLayer
 
 from model_compression_toolkit.core.common import Logger
-from model_compression_toolkit.…  (old line 20 truncated)
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
-
+from model_compression_toolkit.core.common.constants import FOUND_TF
 
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
 
-def is_keras_layer_exportable(layer: Any) -> bool:
-    """
-    Check whether a Keras layer is a valid exportable layer or not.
 
-…  (old lines 29-30 not shown)
+if FOUND_TF:
+    from keras.engine.input_layer import InputLayer
+    from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
 
-…  (old line 32 not shown)
+    def is_keras_layer_exportable(layer: Any) -> bool:
+        """
         Check whether a Keras layer is a valid exportable layer or not.
-    """
-    # Keras Input layers are not wrapped
-    if isinstance(layer, InputLayer):
-        return True
 
-…  (old lines 39-40 not shown)
-        Logger.error(
-            f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
-            f'{type(layer)}')  # pragma: no cover
+        Args:
+            layer: Keras layer to check if considered to be valid for exporting.
 
-…  (old lines 45-49 not shown)
+        Returns:
+            Check whether a Keras layer is a valid exportable layer or not.
+        """
+        # Keras Input layers are not wrapped
+        if isinstance(layer, InputLayer):
+            return True
 
-…  (old line 51 not shown)
-        if not …  (old line 52 truncated)
+        valid_layer = isinstance(layer, KerasQuantizationWrapper)
+        if not valid_layer:
             Logger.error(
-                f'…  (old line 54 truncated)
-                f'{type(…  (old line 55 truncated)
+                f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
+                f'{type(layer)}')  # pragma: no cover
 
-…  (old lines 57-61 not shown)
+        valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+        if not valid_weights_quantizers:
+            Logger.error(
+                f'KerasQuantizationWrapper must have a weights_quantizers but has a '
+                f'{type(layer.weights_quantizers)} object')  # pragma: no cover
+
+        for _, weights_quantizer in layer.weights_quantizers.items():
+            if not isinstance(weights_quantizer, BaseInferableQuantizer):
+                Logger.error(
+                    f'weights_quantizer must be a BaseInferableQuantizer object but has a '
+                    f'{type(weights_quantizer)} object')  # pragma: no cover
 
-…  (old line 63 not shown)
-        if not …  (old line 64 truncated)
+        valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
+        if not valid_activation_quantizers:
             Logger.error(
-                f'…  (old line 66 truncated)
-                f'{type(activation_quantizers)} object')
+                f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
+                f'{type(layer.activation_quantizers)} object')  # pragma: no cover
 
-…  (old lines 69-72 not shown)
+        for activation_quantizers in layer.activation_quantizers:
+            if not isinstance(activation_quantizers, BaseInferableQuantizer):
+                Logger.error(
+                    f'activation_quantizers must be a BaseInferableQuantizer object but has a '
+                    f'{type(activation_quantizers)} object')  # pragma: no cover
 
-…  (old line 74 not shown)
+        quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
+        is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
+        if not is_valid_quantizers:
+            Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer')  # pragma: no cover
+
+        return True
+else:
+    def is_keras_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using is_keras_layer_exportable. '
+                     'Could not find Tensorflow package.')
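Note: is_keras_layer_exportable validates one layer at a time. A plausible pre-export check over a whole model (hypothetical usage; exportable_model stands in for the output of get_exportable_keras_model, and TensorFlow is assumed installed):

    from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable

    # Each call returns True for a valid layer and logs an error otherwise.
    assert all(is_keras_layer_exportable(layer) for layer in exportable_model.layers)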
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED
@@ -13,42 +13,49 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
 
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.core.…  (old line 21 truncated)
-…  (old lines 22-54 not shown)
+from model_compression_toolkit.core.common import Graph, Logger
+from model_compression_toolkit.core.common.constants import FOUND_TORCH
+
+if FOUND_TORCH:
+    import torch
+    from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
+        get_quantization_quantizers
+
+    def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
+        """
+        A function which takes a computational graph node and a pytorch module and
+        perform the quantization wrapping
+
+        Args:
+            node: A node of mct graph.
+            module: A Pytorch module
+
+        Returns: Wrapped layer
+
+        """
+        weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
+        return wrapped_layer
+
+
+    def get_exportable_pytorch_model(graph: Graph):
+        """
+        Convert graph to fully quantized PyTorch model.
+
+        Args:
+            graph: Graph to convert to a PyTorch model.
+
+        Returns:
+            Fully quantized PyTorch model.
+        """
+        return PyTorchModelBuilder(graph=graph,
+                                   wrapper=fully_quantized_wrapper).build_model()
+else:
+    def get_exportable_pytorch_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using get_exportable_pytorch_model. '
+                     'Could not find PyTorch package.')
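Note: the PyTorch builder is gated exactly like the Keras one. A hedged usage sketch (graph is a hypothetical MCT Graph; the docstring says the call returns the fully quantized model, though PyTorchModelBuilder.build_model may also return user info in this version):

    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import \
        get_exportable_pytorch_model

    # With torch installed this builds a wrapped, fully quantized model;
    # without torch, the stub defined under `else:` logs an error instead.
    result = get_exportable_pytorch_model(graph)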
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py
CHANGED
@@ -16,7 +16,8 @@
 from typing import Dict, Any
 
 from model_compression_toolkit.core.common import BaseNode, Logger
-from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
+    SCALE_PER_CHANNEL, CLUSTER_CENTERS
 from model_compression_toolkit.core.common.target_platform import QuantizationMethod
 from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
@@ -45,6 +46,15 @@ def get_weights_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
                 qi_inferable_quantizers_constants.MIN_RANGE: node_w_qc.weights_quantization_params[RANGE_MIN].flatten(),
                 qi_inferable_quantizers_constants.MAX_RANGE: node_w_qc.weights_quantization_params[RANGE_MAX].flatten(),
                 qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
+                qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
+                qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+        # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
+
     else:
         Logger.critical(f'Not supported quantization method for weights inferable quantizers.')  # pragma: no cover
 
@@ -65,6 +75,15 @@ def get_activation_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
         return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
                 qi_inferable_quantizers_constants.MIN_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MIN]]),
                 qi_inferable_quantizers_constants.MAX_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MAX]])}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: np.asarray(
+                    [node_qc.activation_quantization_params[CLUSTER_CENTERS]]),
+                qi_inferable_quantizers_constants.THRESHOLD: np.asarray(
+                    [node_qc.activation_quantization_params[THRESHOLD]]),
+                qi_inferable_quantizers_constants.SIGNED: node_qc.activation_quantization_params.get(SIGNED)}
+        # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
     else:
         Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
 
@@ -111,10 +130,10 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQu
     node_act_qc = node.final_activation_quantization_cfg
     activation_quantization_method = node_act_qc.activation_quantization_method
 
-…  (old lines 114-116 not shown)
+    quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
+                                                       activation_quantization_method,
+                                                       BasePyTorchInferableQuantizer)
    kwargs = get_activation_inferable_quantizer_kwargs(node)
 
-    return …  (old line 119 truncated)
+    return quantizer_for_node(**kwargs)
 
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py
CHANGED
@@ -14,24 +14,31 @@
 # ==============================================================================
 from typing import Any
 
-from model_compression_toolkit.…  (old line 17 truncated)
-from model_compression_toolkit.…  (old line 18 truncated)
-    BasePyTorchInferableQuantizer
+from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.core.common.constants import FOUND_TORCH
 
+if FOUND_TORCH:
+    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
+        BasePyTorchInferableQuantizer
+    def is_pytorch_layer_exportable(layer: Any) -> bool:
+        """
+        Check whether a torch Module is a valid exportable module or not.
 
-
-
-    Check whether a torch Module is a valid exportable module or not.
+        Args:
+            layer: PyTorch module to check if considered to be valid for exporting.
 
-…  (old lines 26-37 not shown)
+        Returns:
+            Check whether a PyTorch layer is a valid exportable layer or not.
+        """
+        if isinstance(layer, PytorchQuantizationWrapper):
+            quantizers = list(layer.weights_quantizers.values())
+            quantizers.extend(layer.activation_quantizers)
+            if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
+                return True
+        return False
+else:
+    def is_pytorch_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using is_pytorch_layer_exportable. '
+                     'Could not find PyTorch package.')
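Note: unlike its Keras counterpart, is_pytorch_layer_exportable returns False for non-wrapper modules instead of logging per check. A plausible model-level validation (hypothetical model variable):

    from model_compression_toolkit import quantizers_infrastructure as qi
    from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable

    # Only wrapped layers are expected to pass; container modules are skipped.
    wrapped = [m for m in model.modules() if isinstance(m, qi.PytorchQuantizationWrapper)]
    assert all(is_pytorch_layer_exportable(m) for m in wrapped)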
model_compression_toolkit/gptq/__init__.py
CHANGED
@@ -12,3 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2
+from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
+from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config