mct-nightly 1.8.0.20052023.post401__py3-none-any.whl → 1.8.0.20230610.post356__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/METADATA +10 -7
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/RECORD +68 -115
- model_compression_toolkit/__init__.py +23 -3
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +16 -9
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +8 -34
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +103 -28
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +39 -44
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +20 -18
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +3 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +36 -9
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +24 -32
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +31 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +5 -5
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +34 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -16
- model_compression_toolkit/gptq/keras/graph_info.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +5 -7
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +6 -6
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -7
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +6 -6
- model_compression_toolkit/gptq/pytorch/gptq_training.py +30 -10
- model_compression_toolkit/gptq/pytorch/graph_info.py +5 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +4 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +7 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -8
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +7 -8
- model_compression_toolkit/qat/common/__init__.py +2 -1
- model_compression_toolkit/qat/common/qat_config.py +2 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -8
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +11 -11
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +11 -12
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +12 -13
- model_compression_toolkit/qat/pytorch/quantization_facade.py +27 -16
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +31 -4
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +10 -9
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +11 -10
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +1 -25
- model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py → trainable_infrastructure/__init__.py} +3 -10
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/base_trainable_quantizer.py +3 -3
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizer_config.py +1 -1
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizers.py +3 -3
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/base_keras_quantizer.py +4 -4
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/config_serialization.py +2 -2
- model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure}/keras/load_model.py +16 -23
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/__init__.py +0 -23
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +0 -87
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +0 -46
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +0 -31
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +0 -53
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +0 -49
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +0 -147
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +0 -345
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +0 -85
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +0 -27
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -148
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -65
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -86
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -111
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +0 -56
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -79
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -179
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -67
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -87
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -163
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +0 -66
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +0 -269
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +0 -152
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +0 -35
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -96
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -62
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -83
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -100
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +0 -95
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +0 -48
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +0 -70
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +0 -57
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +0 -26
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -77
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -106
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -66
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -104
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -109
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +0 -14
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure/common}/__init__.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure → trainable_infrastructure/common}/constants.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/quant_utils.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/trainable_quantizer_config.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/common → trainable_infrastructure/keras}/__init__.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/quantizer_utils.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras → trainable_infrastructure/pytorch}/__init__.py +0 -0
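Taken together, the renames above are one refactor: the inferable quantization infrastructure leaves model_compression_toolkit for the standalone mct_quantizers package, and the trainable infrastructure is promoted from model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure to model_compression_toolkit/trainable_infrastructure. A minimal sketch of what that means for imports, using only names that appear in the hunks below (not an official migration guide):

# Old style, seen on the removed side of the hunks:
#     from model_compression_toolkit import quantizers_infrastructure as qi
# New style, seen on the added side:
from mct_quantizers import (
    QuantizationTarget,
    mark_quantizer,
    KerasQuantizationWrapper,
    KerasActivationQuantizationHolder,
    PytorchQuantizationWrapper,
    PytorchActivationQuantizationHolder,
)
from model_compression_toolkit.trainable_infrastructure import (
    TrainableQuantizerWeightsConfig,
    TrainableQuantizerActivationConfig,
)
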
@@ -21,15 +21,13 @@ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer
     get_activation_inferable_quantizer_kwargs
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
-from …
-…
-from …
-…
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+from mct_quantizers import QuantizationTarget
+from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_weights_config
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
-from model_compression_toolkit.…
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class


@@ -21,7 +21,7 @@ from torch import nn
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from …
+from mct_quantizers import PytorchQuantizationWrapper


 class LinearTempDecay:
@@ -18,8 +18,8 @@ from typing import Dict
 import numpy as np

 from model_compression_toolkit.core.common import max_power_of_two
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
@@ -28,11 +28,11 @@ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
 from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
-from model_compression_toolkit.…
-from …
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


 def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
@@ -68,7 +68,7 @@ def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
                            max_val=int_threshold - 1)


-@mark_quantizer(quantization_target=…
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=RoundingType.SoftQuantizer)
 class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
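
The @mark_quantizer decorator above recurs in nearly every hunk that follows: it tags a trainable-quantizer class with its target (weights or activations), the quantization methods it supports, and a rounding/training type, so the framework can later look the class up (see get_trainable_quantizer_class in the import hunks). A toy re-implementation of one plausible registration mechanism; this is hypothetical scaffolding, not the mct_quantizers source:

from enum import Enum


class QuantizationTarget(Enum):  # stand-in for mct_quantizers.QuantizationTarget
    Weights = 'Weights'
    Activation = 'Activation'


_QUANTIZER_REGISTRY = {}  # (target, method, quantizer_type) -> class


def mark_quantizer(quantization_target, quantization_method, quantizer_type):
    """Toy decorator: record lookup metadata on the class and index it."""
    def wrap(cls):
        cls.quantization_target = quantization_target
        cls.quantization_method = quantization_method
        cls.quantizer_type = quantizer_type
        for method in quantization_method:
            _QUANTIZER_REGISTRY[(quantization_target, method, quantizer_type)] = cls
        return cls
    return wrap


@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=['POWER_OF_TWO', 'SYMMETRIC'],
                quantizer_type='SoftQuantizer')
class DemoWeightsQuantizer:
    pass


assert _QUANTIZER_REGISTRY[(QuantizationTarget.Weights, 'SYMMETRIC',
                            'SoftQuantizer')] is DemoWeightsQuantizer
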
@@ -110,7 +110,7 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: torch.Size,
                                 name: str,
-                                layer: …
+                                layer: PytorchQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary

@@ -17,9 +17,9 @@ import torch.nn as nn
 from typing import Dict
 import numpy as np

-from model_compression_toolkit import …
-from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
@@ -27,10 +27,9 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
 from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
 from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
-from model_compression_toolkit.…
-from …
-…
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import \
     VariableGroup
 from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN

@@ -63,7 +62,7 @@ def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
                          max_val=2 ** num_bits - 1) + min_range


-@mark_quantizer(quantization_target=…
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=RoundingType.SoftQuantizer)
 class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
@@ -100,7 +99,7 @@ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: torch.Size,
                                 name: str,
-                                layer: …
+                                layer: PytorchQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary

@@ -18,8 +18,8 @@ from typing import Dict
 import numpy as np
 from model_compression_toolkit.core.common.defaultdict import DefaultDict

-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
@@ -27,11 +27,10 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
 from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
 from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
 from model_compression_toolkit.constants import THRESHOLD
-from model_compression_toolkit.…
-from …
-…
-from model_compression_toolkit.…
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape


@@ -75,7 +74,7 @@ def pertubation_symmetric_quantizer(input_tensor: torch.Tensor,
     return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)


-@mark_quantizer(quantization_target=…
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=RoundingType.STE)
 class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
@@ -109,7 +108,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: torch.Size,
                                 name: str,
-                                layer: …
+                                layer: PytorchQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure.common.constants import THRESHOLD_TENSOR, \
+    WEIGHTS_QUANTIZATION_PARAMS
@@ -20,8 +20,8 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.logger import Logger


-def …
-…
+def is_qat_applicable(node: common.BaseNode,
+                      fw_info: FrameworkInfo) -> bool:
     """
     A function for deciding if a layer should be fine-tuned during QAT
     Args:
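
The hunk above shows only the new signature of is_qat_applicable; its body is outside the diff. As a hedged sketch of the predicate's likely intent (the helper names below are assumptions, not taken from this diff): a node takes part in QAT fine-tuning when it has a quantized kernel or a quantized activation.

def is_qat_applicable_sketch(node, fw_info) -> bool:
    # Assumed helpers: FrameworkInfo.is_kernel_op and the BaseNode
    # is_*_quantization_enabled predicates; the real body may differ.
    has_quantized_kernel = (fw_info.is_kernel_op(node.type)
                            and node.is_weights_quantization_enabled())
    return has_quantized_kernel or node.is_activation_quantization_enabled()
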
@@ -22,7 +22,7 @@ from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
-from …
+from mct_quantizers import KerasActivationQuantizationHolder, KerasQuantizationWrapper
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.ptq.runner import ptq_runner
@@ -40,20 +40,18 @@ if FOUND_TF:
     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder

     from model_compression_toolkit import get_target_platform_capabilities
-    from model_compression_toolkit import quantizers_infrastructure as qi

     from model_compression_toolkit import get_target_platform_capabilities
     from model_compression_toolkit.core import common
     from model_compression_toolkit.core.common import BaseNode
     from model_compression_toolkit.constants import TENSORFLOW
     from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-    from model_compression_toolkit.qat.common.qat_config import …
+    from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
         get_activation_quantizer_holder
     from model_compression_toolkit.qat.common.qat_config import QATConfig
-    from model_compression_toolkit import quantizers_infrastructure as qi

     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

@@ -71,9 +69,12 @@ if FOUND_TF:
         Returns: Wrapped layer

         """
-        if …
-        weights_quantizers, …
-        …
+        if is_qat_applicable(n, DEFAULT_KERAS_INFO):
+            weights_quantizers, _ = quantization_builder(n,
+                                                         qat_config,
+                                                         DEFAULT_KERAS_INFO)
+            if len(weights_quantizers) > 0:
+                return KerasQuantizationWrapper(layer, weights_quantizers)
         return layer


@@ -255,8 +256,17 @@ if FOUND_TF:

         """
         def _export(layer):
-            if isinstance(layer, …
+            if isinstance(layer, KerasQuantizationWrapper):
                 layer.convert_to_inferable_quantizers()
+            # In the KerasActivationQuantizationHolder case - converting the quantizers only
+            # is not enough. We need to create a new layer with inferable quantizers. The reason for that
+            # is that if we only convert the quantizers, the layer will have some weights (such as min, max,
+            # threshold) that do not match the configuration, thus loading such a model will fail.
+            # To overcome this, the convert_to_inferable_quantizers of KerasActivationQuantizationHolder
+            # creates a new layer from its new configuration after converting the trainable quantizer
+            # to an inferable quantizer.
+            elif isinstance(layer, KerasActivationQuantizationHolder):
+                layer = layer.convert_to_inferable_quantizers()
             return layer

         # clone each layer in the model and apply _export to layers with TrainableQuantizeWrappers
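
For context, _export above is used as a Keras clone_function: tf.keras.models.clone_model walks the model and lets the callback substitute each layer. A runnable sketch of the same driving call with a stand-in callback (clone_model is standard TensorFlow API; the model and callback here are placeholders, not the MCT code):

import tensorflow as tf


def _export(layer):
    # Stand-in for the _export above; this is just the default clone
    # behavior made explicit: rebuild each layer from its own config.
    return layer.__class__.from_config(layer.get_config())


model = tf.keras.Sequential([tf.keras.Input((8,)), tf.keras.layers.Dense(4)])
exported_model = tf.keras.models.clone_model(model, clone_function=_export)
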
@@ -17,7 +17,7 @@ from typing import Union
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TF

-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig, BaseKerasTrainableQuantizer

 if FOUND_TF:
@@ -12,34 +12,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Tuple, Dict, List, …
+from typing import Tuple, Dict, List, Callable

 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.qat.common.qat_config import QATConfig
+from model_compression_toolkit.qat.common.qat_config import QATConfig
 from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from …
-from model_compression_toolkit.…
+from mct_quantizers import QuantizationTarget, KerasActivationQuantizationHolder
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config, \
     get_trainable_quantizer_quantization_candidates
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class


 def get_activation_quantizer_holder(n: common.BaseNode,
-                                    qat_config: QATConfig) -> …
+                                    qat_config: QATConfig) -> Callable:
     """
-    Retrieve a …
+    Retrieve a KerasActivationQuantizationHolder layer to use for activation quantization for a node.
     If the layer is not supposed to be wrapped with activation quantizers - return None.

     Args:
-        n: Node to get …
+        n: Node to get KerasActivationQuantizationHolder to attach in its output.
         qat_config: Configuration of QAT (such as training methods for example).

     Returns:
-        A …
+        A KerasActivationQuantizationHolder layer for the node activation quantization.
     """
     _, activation_quantizers = quantization_builder(n,
                                                     qat_config,
@@ -49,8 +49,8 @@ def get_activation_quantizer_holder(n: common.BaseNode,
     # thus we make sure this is the only possible case (unless it's a node with no activation
     # quantization, which in this case has an empty list).
     if len(activation_quantizers) == 1:
-        return …
-    Logger.error(f'…
+        return KerasActivationQuantizationHolder(activation_quantizers[0])
+    Logger.error(f'KerasActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers were found for node {n}')


 def quantization_builder(n: common.BaseNode,
@@ -19,25 +19,24 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.framework.tensor_shape import TensorShape
 from model_compression_toolkit.constants import SIGNED
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX

 from model_compression_toolkit.qat import TrainingMethod

 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, mark_quantizer, KerasQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
-from model_compression_toolkit import …
+from model_compression_toolkit import constants as C

 from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from …
-…
-…
-    ActivationSymmetricInferableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, \
+    ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


-@mark_quantizer(quantization_target=…
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=TrainingMethod.STE)
 class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
@@ -84,7 +83,7 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: …
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary

@@ -171,7 +170,7 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
                                  input_rank=len(self.threshold_shape))


-@mark_quantizer(quantization_target=…
+@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=TrainingMethod.STE)
 class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
@@ -206,7 +205,7 @@ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: …
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary

@@ -16,25 +16,24 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.framework.tensor_shape import TensorShape
 from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 from model_compression_toolkit.qat import TrainingMethod
-
+
+from mct_quantizers import mark_quantizer, QuantizationMethod, QuantizationTarget, KerasQuantizationWrapper
+from mct_quantizers.keras.quantizers import \
+    BaseKerasInferableQuantizer, WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer

 from model_compression_toolkit.qat.keras.quantizer.quant_utils import adjust_range_to_include_zero
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
-from model_compression_toolkit import …
+from model_compression_toolkit import constants as C

 from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from model_compression_toolkit.…
-    mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
-    BaseKerasInferableQuantizer, WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


-@mark_quantizer(quantization_target=…
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=TrainingMethod.STE)
 class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
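
Both uniform STE quantizers above import range-fixing helpers (adjust_range_to_include_zero, fix_range_to_include_zero). The reason a uniform range gets nudged is that 0.0 must fall exactly on the quantization grid, e.g. so padding and zeroed weights survive quantization. A toy version of the adjustment under standard asymmetric-quantization assumptions; the in-tree helpers may differ in detail:

import numpy as np


def toy_fix_range_to_include_zero(range_min, range_max, n_bits):
    """Snap [range_min, range_max] so that 0.0 is exactly representable."""
    if range_min > 0:   # all-positive range: anchor the minimum at zero
        return 0.0, range_max
    if range_max < 0:   # all-negative range: anchor the maximum at zero
        return range_min, 0.0
    scale = (range_max - range_min) / (2 ** n_bits - 1)
    zero_point = np.round(-range_min / scale)  # grid index closest to 0.0
    return -zero_point * scale, (2 ** n_bits - 1 - zero_point) * scale


print(toy_fix_range_to_include_zero(-0.97, 1.03, 8))  # both ends shift slightly
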
@@ -73,7 +72,7 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: …
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary

@@ -148,7 +147,7 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
                             input_rank=len(self.min_max_shape))


-@mark_quantizer(quantization_target=…
+@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=TrainingMethod.STE)
 class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
@@ -173,7 +172,7 @@ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: …
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary

@@ -25,28 +25,32 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import …
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
+    TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.ptq.runner import ptq_runner

-
 if FOUND_TORCH:
     import torch.nn as nn
     from torch.nn import Module
+    from mct_quantizers import PytorchActivationQuantizationHolder
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
-    from model_compression_toolkit.qat.common.qat_config import …
+    from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
     from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
-    from …
-    from model_compression_toolkit import quantizers_infrastructure as qi
+    from mct_quantizers import PytorchQuantizationWrapper
     from model_compression_toolkit import get_target_platform_capabilities
     from model_compression_toolkit.qat.common.qat_config import QATConfig
+    from model_compression_toolkit.qat.pytorch.quantizer.quantization_builder import get_activation_quantizer_holder
     from model_compression_toolkit.qat.pytorch.quantizer.quantization_builder import quantization_builder
+
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)


-    def qat_wrapper(n: common.BaseNode, …
+    def qat_wrapper(n: common.BaseNode,
+                    module: nn.Module,
+                    qat_config: QATConfig):
         """
         A function which takes a computational graph node and a pytorch module and perform the quantization wrapping
         Args:
@@ -56,11 +60,11 @@ if FOUND_TORCH:
         Returns: Wrapped layer

         """
-        if …
-        weights_quantizers, …
-        …
-        …
-        …
+        if is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
+            weights_quantizers, _ = quantization_builder(n, qat_config, DEFAULT_PYTORCH_INFO)
+            if len(weights_quantizers) > 0:
+                return PytorchQuantizationWrapper(module, weights_quantizers)
+        return module


     def pytorch_quantization_aware_training_init(in_model: Module,
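
In usage terms, the rewritten qat_wrapper makes wrapping conditional: a module that has nothing to quantize passes through untouched. A standalone illustration of the same decision, with an empty quantizer dict standing in for a node that fails is_qat_applicable (PytorchQuantizationWrapper is the real mct_quantizers class; everything else here is scaffolding):

import torch.nn as nn
from mct_quantizers import PytorchQuantizationWrapper

conv = nn.Conv2d(3, 8, kernel_size=3)
weights_quantizers = {}  # would map attribute name -> trainable quantizer

if len(weights_quantizers) > 0:
    module = PytorchQuantizationWrapper(conv, weights_quantizers)
else:
    module = conv  # nothing to quantize: return the bare module

print(type(module).__name__)  # Conv2d
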
@@ -135,11 +139,11 @@ if FOUND_TORCH:
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-…
-…
+                             "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
+                             "or pass a valid mixed precision configuration.")

         Logger.info("Using experimental mixed-precision quantization. "
-…
+                    "If you encounter an issue please file a bug.")

         tb_w = _init_tensorboard_writer(fw_info)

@@ -158,12 +162,18 @@ if FOUND_TORCH:

         _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

-        qat_model, user_info = PyTorchModelBuilder(graph=tg, …
+        qat_model, user_info = PyTorchModelBuilder(graph=tg,
+                                                   fw_info=fw_info,
+                                                   wrapper=_qat_wrapper,
+                                                   get_activation_quantizer_holder_fn=partial(
+                                                       get_activation_quantizer_holder,
+                                                       qat_config=qat_config)).build_model()

         user_info.mixed_precision_cfg = bit_widths_config

         return qat_model, user_info

+
     def pytorch_quantization_aware_training_finalize(in_model: Module):
         """
         Convert a model fine-tuned by the user to a network with QuantizeWrappers containing
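
The builder call above threads configuration into its callbacks with functools.partial rather than closures: qat_config is bound once, and each callback keeps the positional signature the builder expects. The same pattern in isolation (toy names):

from functools import partial


def wrap(node, module, qat_config):
    """Stand-in for qat_wrapper: the builder calls it with (node, module) only."""
    return node, module, qat_config


bound = partial(wrap, qat_config='demo-config')
assert bound('node', 'module') == ('node', 'module', 'demo-config')
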
@@ -207,7 +217,7 @@ if FOUND_TORCH:
         """
         exported_model = copy.deepcopy(in_model)
         for _, layer in exported_model.named_children():
-            if isinstance(layer, PytorchQuantizationWrapper):
+            if isinstance(layer, (PytorchQuantizationWrapper, PytorchActivationQuantizationHolder)):
                 layer.convert_to_inferable_quantizers()

         return exported_model
@@ -221,6 +231,7 @@ else:
             'when using pytorch_quantization_aware_training_init. '
             'Could not find the torch package.') # pragma: no cover

+
     def pytorch_quantization_aware_training_finalize(*args, **kwargs):
         Logger.critical('Installing Pytorch is mandatory '
             'when using pytorch_quantization_aware_training_finalize. '
@@ -17,9 +17,9 @@ from typing import Union
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TORCH

-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from model_compression_toolkit.…
+from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
     BasePytorchTrainableQuantizer

 if FOUND_TORCH:
@@ -12,19 +12,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import List, Dict, Tuple
+from typing import List, Dict, Tuple, Callable
+
+from mct_quantizers import PytorchActivationQuantizationHolder, QuantizationTarget

 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.qat.common.qat_config import QATConfig
-from model_compression_toolkit.…
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_quantization_candidates, get_trainable_quantizer_weights_config, \
     get_trainable_quantizer_activation_config
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
-from model_compression_toolkit.…
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class

+def get_activation_quantizer_holder(n: common.BaseNode,
+                                    qat_config: QATConfig) -> Callable:
+    """
+    Retrieve an ActivationQuantizationHolder layer to use for activation quantization for a node.
+    If the layer is not supposed to be wrapped with activation quantizers - return None.
+
+    Args:
+        n: Node for which to retrieve an ActivationQuantizationHolder to attach to its output.
+        qat_config: QAT configuration (for example, training methods).
+
+    Returns:
+        An ActivationQuantizationHolder layer for the node's activation quantization.
+    """
+    _, activation_quantizers = quantization_builder(n,
+                                                    qat_config,
+                                                    DEFAULT_PYTORCH_INFO)
+
+    # Holder by definition uses a single quantizer for the activation quantization
+    # thus we make sure this is the only possible case (unless it's a node with no activation
+    # quantization, which in this case has an empty list).
+    if len(activation_quantizers) == 1:
+        return PytorchActivationQuantizationHolder(activation_quantizers[0])
+    Logger.error(f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers were found for node {n}')
+

 def quantization_builder(n: common.BaseNode,
                          qat_config: QATConfig,