mct-nightly 1.8.0.20052023.post401__py3-none-any.whl → 1.8.0.20230610.post356__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/METADATA +10 -7
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/RECORD +68 -115
- model_compression_toolkit/__init__.py +23 -3
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +16 -9
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +8 -34
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +103 -28
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +39 -44
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +20 -18
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +3 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +36 -9
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +24 -32
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +31 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +5 -5
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +34 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -16
- model_compression_toolkit/gptq/keras/graph_info.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +5 -7
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +6 -6
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -7
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +6 -6
- model_compression_toolkit/gptq/pytorch/gptq_training.py +30 -10
- model_compression_toolkit/gptq/pytorch/graph_info.py +5 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +4 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +7 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -8
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +7 -8
- model_compression_toolkit/qat/common/__init__.py +2 -1
- model_compression_toolkit/qat/common/qat_config.py +2 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -8
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +11 -11
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +11 -12
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +12 -13
- model_compression_toolkit/qat/pytorch/quantization_facade.py +27 -16
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +31 -4
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +10 -9
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +11 -10
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +1 -25
- model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py → trainable_infrastructure/__init__.py} +3 -10
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/base_trainable_quantizer.py +3 -3
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizer_config.py +1 -1
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizers.py +3 -3
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/base_keras_quantizer.py +4 -4
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/config_serialization.py +2 -2
- model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure}/keras/load_model.py +16 -23
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/__init__.py +0 -23
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +0 -87
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +0 -46
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +0 -31
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +0 -53
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +0 -49
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +0 -147
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +0 -345
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +0 -85
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +0 -27
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -148
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -65
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -86
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -111
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +0 -56
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -79
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -179
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -67
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -87
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -163
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +0 -66
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +0 -269
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +0 -152
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +0 -35
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -96
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -62
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -83
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -100
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +0 -95
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +0 -48
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +0 -70
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +0 -57
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +0 -26
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -77
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -106
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -66
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -104
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -109
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +0 -14
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure/common}/__init__.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure → trainable_infrastructure/common}/constants.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/quant_utils.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/trainable_quantizer_config.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/common → trainable_infrastructure/keras}/__init__.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/quantizer_utils.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras → trainable_infrastructure/pytorch}/__init__.py +0 -0
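The headline change in this release is the removal of the in-tree `quantizers_infrastructure` package: inferable quantizers, wrappers, and quantization holders now come from the external `mct_quantizers` package, while the trainable pieces move up to a top-level `trainable_infrastructure` package. A rough migration map compiled from the hunks below (a sketch, not an exhaustive list):

```python
# Inferable quantizers, wrappers, and holders now live in the external
# mct_quantizers package (all names below appear in the hunks in this diff):
from mct_quantizers import (
    BaseInferableQuantizer,
    KerasQuantizationWrapper,
    KerasActivationQuantizationHolder,
    PytorchQuantizationWrapper,
    PytorchActivationQuantizationHolder,
    QuantizationTarget,
    mark_quantizer,
)

# Trainable quantizer infrastructure moved up one package level:
#   model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.*
#     -> model_compression_toolkit.trainable_infrastructure.*
from model_compression_toolkit.trainable_infrastructure import (
    TrainableQuantizerWeightsConfig,
    TrainableQuantizerActivationConfig,
)
```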
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py
CHANGED
@@ -14,16 +14,14 @@
 # ==============================================================================
 from typing import Any
 
-
-from model_compression_toolkit.logger import Logger
+from mct_quantizers import BaseInferableQuantizer, KerasActivationQuantizationHolder
 from model_compression_toolkit.constants import FOUND_TF
-
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
-
+from model_compression_toolkit.logger import Logger
 
 if FOUND_TF:
+    from keras.engine.base_layer import Layer
     from keras.engine.input_layer import InputLayer
-    from
+    from mct_quantizers import KerasQuantizationWrapper
 
     def is_keras_layer_exportable(layer: Any) -> bool:
         """
@@ -39,40 +37,34 @@ if FOUND_TF:
         if isinstance(layer, InputLayer):
             return True
 
-        valid_layer = isinstance(layer,
+        valid_layer = isinstance(layer, Layer)
         if not valid_layer:
             Logger.error(
-                f'Exportable layer must be
+                f'Exportable layer must be a Keras layer, but layer {layer.name} is of type '
                 f'{type(layer)}') # pragma: no cover
 
-
-
-
-                f'KerasQuantizationWrapper must have a weights_quantizers but has a '
-                f'{type(layer.weights_quantizers)} object') # pragma: no cover
-
-        for _, weights_quantizer in layer.weights_quantizers.items():
-            if not isinstance(weights_quantizer, BaseInferableQuantizer):
+        if isinstance(layer, KerasQuantizationWrapper):
+            valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+            if not valid_weights_quantizers:
                 Logger.error(
-                    f'
-                    f'{type(
+                    f'KerasQuantizationWrapper must have a weights_quantizers but has a '
+                    f'{type(layer.weights_quantizers)} object') # pragma: no cover
 
-
-
-        Logger.error(
-            f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
-            f'{type(layer.activation_quantizers)} object') # pragma: no cover
+            if len(layer.weights_quantizers) == 0:
+                Logger.error(f'KerasQuantizationWrapper must have at least one weight quantizer, but found {len(layer.weights_quantizers)} quantizers. If layer is not quantized it should be a Keras layer.')
 
-
-
-
-
-
+            for _, weights_quantizer in layer.weights_quantizers.items():
+                if not isinstance(weights_quantizer, BaseInferableQuantizer):
+                    Logger.error(
+                        f'weights_quantizer must be a BaseInferableQuantizer object but has a '
+                        f'{type(weights_quantizer)} object') # pragma: no cover
 
-
-
-
-
+        if isinstance(layer, KerasActivationQuantizationHolder):
+            if not isinstance(layer.activation_holder_quantizer, BaseInferableQuantizer):
+                Logger.error(
+                    f'activation quantizer in KerasActivationQuantizationHolder'
+                    f' must be a BaseInferableQuantizer object but has a '
+                    f'{type(layer.activation_holder_quantizer)} object') # pragma: no cover
 
         return True
 else:
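A quick way to exercise the reworked check (a minimal sketch, assuming `exported_model` is a Keras model produced by MCT's exporter; invalid layers are reported through `Logger.error`):

```python
# Minimal sketch: run the exported-layer check over a whole Keras model.
from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import \
    is_keras_layer_exportable

def validate_exported_keras_model(exported_model) -> bool:
    # Every layer must pass: plain Keras layers, KerasQuantizationWrapper
    # instances with a non-empty weights_quantizers dict, and
    # KerasActivationQuantizationHolder layers holding an inferable quantizer.
    return all(is_keras_layer_exportable(layer) for layer in exported_model.layers)
```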
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED
@@ -13,20 +13,23 @@
 # limitations under the License.
 # ==============================================================================
 
-
-from model_compression_toolkit import quantizers_infrastructure as qi
+from typing import Union, Callable
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.constants import FOUND_TORCH
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core.common import BaseNode
 
 if FOUND_TORCH:
     import torch
+    from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
     from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
         get_quantization_quantizers
 
-
+
+    def fully_quantized_wrapper(node: common.BaseNode,
+                                module: torch.nn.Module) -> Union[torch.nn.Module,PytorchQuantizationWrapper]:
         """
         A function which takes a computational graph node and a pytorch module and
         perform the quantization wrapping
@@ -34,14 +37,32 @@ if FOUND_TORCH:
         Args:
             node: A node of mct graph.
             module: A Pytorch module
-
         Returns: Wrapped layer
 
         """
-        weight_quantizers,
-
-
+        weight_quantizers, _ = get_quantization_quantizers(node)
+        if len(weight_quantizers) > 0:
+            return PytorchQuantizationWrapper(module, weight_quantizers)
+        return module
 
+    def get_activation_quantizer_holder(node: BaseNode) -> Callable:
+        """
+        Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
+        If the layer is not supposed to be wrapped with an activation quantizer - return None.
+        Args:
+            node: Node to attach a PytorchActivationQuantizationHolder to its output.
+        Returns:
+            A PytorchActivationQuantizationHolder module for the node's activation quantization.
+        """
+        _, activation_quantizers = get_quantization_quantizers(node)
+        # Holder by definition uses a single quantizer for the activation quantization
+        # thus we make sure this is the only possible case (unless it's a node we no activation
+        # quantization, which in this case has an empty list).
+        if len(activation_quantizers) == 1:
+            return PytorchActivationQuantizationHolder(activation_quantizers[0])
+        Logger.error(
+            f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+            f'were found for node {node}')
 
     def get_exportable_pytorch_model(graph: Graph):
         """
@@ -54,7 +75,9 @@ if FOUND_TORCH:
             Fully quantized PyTorch model.
         """
         return PyTorchModelBuilder(graph=graph,
-                                   wrapper=fully_quantized_wrapper
+                                   wrapper=fully_quantized_wrapper,
+                                   get_activation_quantizer_holder_fn=get_activation_quantizer_holder).build_model()
+
 else:
     def get_exportable_pytorch_model(*args, **kwargs): # pragma: no cover
         Logger.error('Installing torch is mandatory '
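Condensed, the new export policy does two independent things per node: weight quantizers ride on a wrapper around the module, and at most one activation quantizer is packaged in a standalone holder module. A sketch of that policy using only the constructors shown above (`module`, `weight_quantizers`, and `activation_quantizers` stand in for the per-node values returned by `get_quantization_quantizers`):

```python
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder

def attach_quantizers(module, weight_quantizers, activation_quantizers):
    # Rule 1: wrap only modules that actually have weight quantizers.
    if len(weight_quantizers) > 0:
        module = PytorchQuantizationWrapper(module, weight_quantizers)
    # Rule 2: a holder carries exactly one activation quantizer; nodes without
    # activation quantization get no holder at all.
    holder = (PytorchActivationQuantizationHolder(activation_quantizers[0])
              if len(activation_quantizers) == 1 else None)
    return module, holder
```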
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py
CHANGED
@@ -20,11 +20,11 @@ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RA
     SCALE_PER_CHANNEL, CLUSTER_CENTERS
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
-from
-from
-
-
-
+from mct_quantizers import QuantizationTarget
+from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+from mct_quantizers import \
+    constants as qi_inferable_quantizers_constants
+from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
 import numpy as np
 
 
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py
CHANGED
@@ -17,10 +17,13 @@ from typing import Any
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TORCH
 
+
 if FOUND_TORCH:
-
-    from
-
+    import torch.nn as nn
+    from mct_quantizers import PytorchQuantizationWrapper
+    from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
+    from mct_quantizers.pytorch.activation_quantization_holder import PytorchActivationQuantizationHolder
+
     def is_pytorch_layer_exportable(layer: Any) -> bool:
         """
         Check whether a torch Module is a valid exportable module or not.
@@ -31,12 +34,35 @@ if FOUND_TORCH:
         Returns:
             Check whether a PyTorch layer is a valid exportable layer or not.
         """
+        if not isinstance(layer, nn.Module):
+            Logger.error(f'Exportable layer must be a nn.Module layer, but layer {layer.name} is of type {type(layer)}') # pragma: no cover
+
         if isinstance(layer, PytorchQuantizationWrapper):
-
-
-
-
+            valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+            if not valid_weights_quantizers:
+                Logger.error(
+                    f'PytorchQuantizationWrapper must have a weights_quantizers but has a '
+                    f'{type(layer.weights_quantizers)} object') # pragma: no cover
+
+            if len(layer.weights_quantizers) == 0:
+                Logger.error(f'PytorchQuantizationWrapper must have at least one weight quantizer, but found {len(layer.weights_quantizers)} quantizers.'
+                             f'If layer is not quantized it should be a Keras layer.')
+
+            for _, weights_quantizer in layer.weights_quantizers.items():
+                if not isinstance(weights_quantizer, BasePyTorchInferableQuantizer):
+                    Logger.error(
+                        f'weights_quantizer must be a BasePyTorchInferableQuantizer object but has a '
+                        f'{type(weights_quantizer)} object') # pragma: no cover
+
+        elif isinstance(layer, PytorchActivationQuantizationHolder):
+            if not isinstance(layer.activation_holder_quantizer, BasePyTorchInferableQuantizer):
+                Logger.error(
+                    f'activation quantizer in PytorchActivationQuantizationHolder'
+                    f' must be a BasePyTorchInferableQuantizer object but has a '
+                    f'{type(layer.activation_holder_quantizer)} object') # pragma: no cover
+
+        return True
+
 else:
     def is_pytorch_layer_exportable(*args, **kwargs): # pragma: no cover
         Logger.error('Installing torch is mandatory '
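As with the Keras variant, the check can be run over a whole exported model (a minimal sketch; `exported_model` is assumed to be a `torch.nn.Module` produced by MCT's exporter):

```python
from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import \
    is_pytorch_layer_exportable

def validate_exported_pytorch_model(exported_model) -> bool:
    # Wrappers must hold a non-empty dict of BasePyTorchInferableQuantizer
    # objects; holders must carry a single inferable activation quantizer.
    return all(is_pytorch_layer_exportable(m) for m in exported_model.modules())
```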
model_compression_toolkit/gptq/keras/gptq_training.py
CHANGED
@@ -26,8 +26,7 @@ from model_compression_toolkit.core.keras.back2framework.keras_model_builder imp
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
 from model_compression_toolkit.logger import Logger
-from
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder
+from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder
 
 if version.parse(tf.__version__) < version.parse("2.6"):
     from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
@@ -45,7 +44,6 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
 import numpy as np
 import copy
 from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
-from model_compression_toolkit import quantizers_infrastructure as qi
 
 
 class KerasGPTQTrainer(GPTQTrainer):
@@ -133,7 +131,7 @@ class KerasGPTQTrainer(GPTQTrainer):
 
     def gptq_wrapper(self,
                      n: common.BaseNode,
-                     layer: Layer) -> Union[
+                     layer: Layer) -> Union[KerasQuantizationWrapper, Layer]:
         """
         A function which takes a computational graph node and a keras layer and perform the quantization wrapping.
 
@@ -145,22 +143,23 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         """
         if self._is_gptq_weights_trainable(n):
-            weights_quantizers, _ = quantization_builder(n,
-
-
-
-
-
-
+            weights_quantizers, _ = quantization_builder(n,
+                                                         self.gptq_config) # TODO: split quantizers building into two functions: for weights and activations
+            if len(weights_quantizers) > 0:
+                return KerasQuantizationWrapper(layer,
+                                                weights_quantizers=weights_quantizers)
+        return layer
+
+    def get_activation_quantizer_holder(self, n: common.BaseNode) -> Callable:
         """
-        Retrieve a
+        Retrieve a KerasActivationQuantizationHolder layer to use for activation quantization for a node.
         If the layer is not supposed to be wrapped with activation quantizers - return None.
 
         Args:
-            n: Node to get
+            n: Node to get KerasActivationQuantizationHolder to attach in its output.
 
         Returns:
-            A
+            A KerasActivationQuantizationHolder layer for the node activation quantization.
         """
         _, activation_quantizers = quantization_builder(n, self.gptq_config) # TODO: split quantizers building into two functions: for weights and activations
 
@@ -168,10 +167,10 @@ class KerasGPTQTrainer(GPTQTrainer):
         # thus we make sure this is the only possible case (unless it's a node with no activation
         # quantization, which in this case has an empty list).
         if len(activation_quantizers) == 1:
-            return
+            return KerasActivationQuantizationHolder(activation_quantizers[0])
 
         Logger.error(
-            f'
+            f'KerasActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
             f'were found for node {n}')
 
 
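The Keras GPTQ trainer now mirrors the exporter's split: `gptq_wrapper` attaches trainable weight quantizers via `KerasQuantizationWrapper`, while `get_activation_quantizer_holder` packages the node's single activation quantizer in a `KerasActivationQuantizationHolder`. A condensed sketch of the two attachment points (the per-node quantizer lists from `quantization_builder` are assumed):

```python
from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder

def wrap_for_gptq(layer, weights_quantizers, activation_quantizers):
    # Weight quantizers wrap the layer itself (only if there are any).
    if len(weights_quantizers) > 0:
        layer = KerasQuantizationWrapper(layer, weights_quantizers=weights_quantizers)
    # The activation quantizer becomes a separate layer on the node's output;
    # an empty list means the node's activation is not quantized.
    holder = (KerasActivationQuantizationHolder(activation_quantizers[0])
              if len(activation_quantizers) == 1 else None)
    return layer, holder
```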
model_compression_toolkit/gptq/keras/graph_info.py
CHANGED
@@ -21,8 +21,8 @@ from tensorflow.keras.models import Model
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
-from
-from model_compression_toolkit.
+from mct_quantizers import KerasQuantizationWrapper
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def get_gptq_trainable_parameters(fxp_model: Model,
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py
CHANGED
@@ -19,15 +19,14 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
 
-from model_compression_toolkit.
-
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
 
 if FOUND_TF:
     import tensorflow as tf
 
-    from model_compression_toolkit.
-
+    from model_compression_toolkit.trainable_infrastructure import BaseKerasTrainableQuantizer
+    from mct_quantizers import KerasQuantizationWrapper
 
 class BaseKerasGPTQTrainableQuantizer(BaseKerasTrainableQuantizer):
     """
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py
CHANGED
@@ -21,14 +21,12 @@ from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quan
     get_inferable_quantizer_kwargs
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
-from
-from
-
-from model_compression_toolkit.
-    BaseKerasInferableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+from mct_quantizers import QuantizationTarget
+from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_weights_config
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class
 
 
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
CHANGED
@@ -19,7 +19,7 @@ from keras import Model
 
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from
+from mct_quantizers import KerasQuantizationWrapper
 
 
 class LinearTempDecay:
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py
CHANGED
@@ -17,9 +17,9 @@ import tensorflow as tf
 import numpy as np
 
 from model_compression_toolkit.gptq import RoundingType
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core.common import max_power_of_two
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget
 from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
@@ -27,11 +27,11 @@ from typing import Dict, Any
 from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
 from model_compression_toolkit.gptq.keras.quantizer.quant_utils import power_of_two_max, clip, calculate_delta
-from model_compression_toolkit.
-from
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
@@ -66,7 +66,7 @@ def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
     return delta * clip(tensor_q, max_val=max_int, min_val=min_int)
 
 
-@mark_quantizer(quantization_target=
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=RoundingType.SoftQuantizer)
 class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
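The `@mark_quantizer` decorator imported from `mct_quantizers` replaces the old infrastructure's registration decorator; the same pattern recurs in the uniform and STE hunks below. A self-contained sketch of the pattern (the decorated class is a hypothetical stand-in for a `BaseKerasGPTQTrainableQuantizer` subclass):

```python
from mct_quantizers import QuantizationTarget, mark_quantizer
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.gptq import RoundingType

@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                quantizer_type=RoundingType.SoftQuantizer)
class MyWeightGPTQQuantizer:
    # Hypothetical stand-in: real quantizers subclass
    # BaseKerasGPTQTrainableQuantizer and implement the training logic.
    pass
```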
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py
CHANGED
@@ -17,20 +17,20 @@ import tensorflow as tf
 import numpy as np
 
 from model_compression_toolkit.gptq import RoundingType
-from model_compression_toolkit import
-from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget
 from model_compression_toolkit.gptq.common.gptq_constants import \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
 from typing import Dict, Any
 from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
-from model_compression_toolkit.
-from
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def soft_rounding_uniform_quantizer(input_tensor: tf.Tensor,
@@ -61,7 +61,7 @@ def soft_rounding_uniform_quantizer(input_tensor: tf.Tensor,
                            max_val=2 ** num_bits - 1) + min_range
 
 
-@mark_quantizer(quantization_target=
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=RoundingType.SoftQuantizer)
 class UniformSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
@@ -19,18 +19,18 @@ import numpy as np
|
|
|
19
19
|
import tensorflow as tf
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit.gptq import RoundingType
|
|
22
|
-
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
23
22
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
|
+
from mct_quantizers import QuantizationTarget
|
|
24
24
|
from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
|
|
25
25
|
from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
|
|
26
26
|
from model_compression_toolkit.constants import THRESHOLD
|
|
27
27
|
from model_compression_toolkit.core.common.defaultdict import DefaultDict
|
|
28
28
|
from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
|
|
29
|
-
from model_compression_toolkit.
|
|
30
|
-
from
|
|
31
|
-
from model_compression_toolkit.
|
|
29
|
+
from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
|
|
30
|
+
from mct_quantizers import mark_quantizer
|
|
31
|
+
from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
|
|
32
32
|
get_threshold_reshape_shape
|
|
33
|
-
from model_compression_toolkit.
|
|
33
|
+
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
|
|
@@ -67,7 +67,7 @@ def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
|
|
|
67
67
|
return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)
|
|
68
68
|
|
|
69
69
|
|
|
70
|
-
@mark_quantizer(quantization_target=
|
|
70
|
+
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
|
|
71
71
|
quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
|
|
72
72
|
quantizer_type=RoundingType.STE)
|
|
73
73
|
class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
|
|
model_compression_toolkit/gptq/pytorch/gptq_training.py
CHANGED
@@ -32,9 +32,8 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_mo
 from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
     get_weights_for_loss
 from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
-from
+from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
 
 
 class PytorchGPTQTrainer(GPTQTrainer):
@@ -90,8 +89,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
         self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
 
-    def
-
+    def _is_gptq_weights_trainable(self,
+                                   node: BaseNode) -> bool:
         """
         A function for deciding if a layer should be fine-tuned during GPTQ.
         Args:
@@ -105,7 +104,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
                          f"without a kernel isn't supported.")
         return node.is_weights_quantization_enabled()
 
-    def gptq_wrapper(self,
+    def gptq_wrapper(self,
+                     n: BaseNode,
+                     layer: Module) -> Union[PytorchQuantizationWrapper, Module]:
         """
         A function which takes a computational graph node and a pytorch layer and perform the quantization wrapping.
 
@@ -116,14 +117,32 @@ class PytorchGPTQTrainer(GPTQTrainer):
         Returns: Wrapped layer if the layer should be wrap, otherwise returns the layer as is.
         """
 
-        if self.
+        if self._is_gptq_weights_trainable(n):
             weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config)
-            return
-
-                                               activation_quantizers=activation_quantizers)
+            return PytorchQuantizationWrapper(layer,
+                                              weights_quantizers=weights_quantizers)
         else:
             return layer
 
+    def get_activation_quantizer_holder(self, n: BaseNode) -> Callable:
+        """
+        Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
+        If the layer is not supposed to be wrapped with an activation quantizer - return None.
+        Args:
+            n: Node to attach a PytorchActivationQuantizationHolder to its output.
+        Returns:
+            A PytorchActivationQuantizationHolder module for the node's activation quantization.
+        """
+        _, activation_quantizers = quantization_builder(n, self.gptq_config)
+        # Holder by definition uses a single quantizer for the activation quantization
+        # thus we make sure this is the only possible case (unless it's a node we no activation
+        # quantization, which in this case has an empty list).
+        if len(activation_quantizers) == 1:
+            return PytorchActivationQuantizationHolder(activation_quantizers[0])
+        Logger.error(
+            f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+            f'were found for node {n}')
+
     def build_gptq_model(self):
         """
         Build the GPTQ model with QuantizationWrappers
@@ -134,7 +153,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
                                                  append2output=self.compare_points,
                                                  fw_info=self.fw_info,
                                                  wrapper=self.gptq_wrapper,
-                                                 return_float_outputs=True
+                                                 return_float_outputs=True,
+                                                 get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
 
         return gptq_model, gptq_user_info
 
model_compression_toolkit/gptq/pytorch/graph_info.py
CHANGED
@@ -18,8 +18,9 @@ from typing import List
 from model_compression_toolkit.core.pytorch.constants import BIAS
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from model_compression_toolkit.
-from
+from model_compression_toolkit.logger import Logger
+from mct_quantizers import PytorchQuantizationWrapper
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def get_gptq_trainable_parameters(fxp_model: nn.Module,
@@ -46,6 +47,8 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
                                                            fw_info=DEFAULT_PYTORCH_INFO)
 
         # collect trainable weights per quantizer
+        if kernel_attribute not in layer.weights_quantizers:
+            Logger.error(f'{kernel_attribute} was not found in weight quantizers of layer {layer.layer}')
         quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
         quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
         trainable_aux_weights.extend(quantizer_trainable_weights)
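The guard added above protects the variable collection that follows it. For reference, that collection splits each weight quantizer's trainables into two groups (a sketch built from the names in the hunk; `layer` is a `PytorchQuantizationWrapper` whose kernel attribute is quantized):

```python
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup

def trainable_vars(layer, kernel_attribute):
    quantizer = layer.weights_quantizers[kernel_attribute]
    # Auxiliary rounding variables (WEIGHTS) and quantization parameters such
    # as thresholds (QPARAMS) are collected separately so GPTQ can hand them
    # to different optimizers.
    return (quantizer.get_trainable_variables(VariableGroup.WEIGHTS),
            quantizer.get_trainable_variables(VariableGroup.QPARAMS))
```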
model_compression_toolkit/gptq/pytorch/quantization_facade.py
CHANGED
@@ -53,7 +53,8 @@ if FOUND_TORCH:
                                optimizer: Optimizer = Adam([torch.Tensor([])], lr=LR_DEFAULT),
                                optimizer_rest: Optimizer = Adam([torch.Tensor([])], lr=LR_REST_DEFAULT),
                                loss: Callable = multiple_tensors_mse_loss,
-                               log_function: Callable = None
+                               log_function: Callable = None,
+                               use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
         """
         Create a GradientPTQConfigV2 instance for Pytorch models.
 
@@ -63,6 +64,7 @@ if FOUND_TORCH:
             optimizer_rest (Optimizer): Pytorch optimizer to use for fine-tuning of the bias variable.
             loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
             log_function (Callable): Function to log information about the gptq process.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
 
         returns:
             a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -84,7 +86,7 @@ if FOUND_TORCH:
         """
         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
         return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
-                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)
+                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, use_hessian_based_weights=use_hessian_based_weights)
 
 
     def pytorch_gradient_post_training_quantization_experimental(model: Module,
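The facade change surfaces the new `use_hessian_based_weights` knob to users. Assuming the public entry point is `get_pytorch_gptq_config` (the hunk shows its parameter list but not its name), disabling the Hessian-based loss weighting would look like:

```python
# Hypothetical usage: the facade name and the n_epochs parameter are
# assumptions inferred from the GradientPTQConfigV2 construction above.
import model_compression_toolkit as mct

gptq_config = mct.get_pytorch_gptq_config(n_epochs=5,
                                          use_hessian_based_weights=False)
```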
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py
CHANGED
@@ -19,16 +19,16 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TORCH
 from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
 
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import \
     BaseTrainableQuantizer
-from model_compression_toolkit.
+from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
     BasePytorchTrainableQuantizer
 
 if FOUND_TORCH:
     from torch import Tensor
-    from
+    from mct_quantizers import PytorchQuantizationWrapper
 
 class BasePytorchGPTQTrainableQuantizer(BasePytorchTrainableQuantizer):
     """