mct-nightly 1.9.0.20230621.post405__py3-none-any.whl → 1.9.0.20230623.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.9.0.20230621.post405.dist-info → mct_nightly-1.9.0.20230623.post423.dist-info}/METADATA +1 -1
- {mct_nightly-1.9.0.20230621.post405.dist-info → mct_nightly-1.9.0.20230623.post423.dist-info}/RECORD +9 -9
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +33 -26
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +50 -23
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +7 -1
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +7 -1
- {mct_nightly-1.9.0.20230621.post405.dist-info → mct_nightly-1.9.0.20230623.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.9.0.20230621.post405.dist-info → mct_nightly-1.9.0.20230623.post423.dist-info}/WHEEL +0 -0
- {mct_nightly-1.9.0.20230621.post405.dist-info → mct_nightly-1.9.0.20230623.post423.dist-info}/top_level.txt +0 -0
{mct_nightly-1.9.0.20230621.post405.dist-info → mct_nightly-1.9.0.20230623.post423.dist-info}/RECORD
RENAMED
|
@@ -251,13 +251,13 @@ model_compression_toolkit/exporter/model_wrapper/keras/__init__.py,sha256=cco4Tm
|
|
|
251
251
|
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py,sha256=ihcMbqi_UGYnDZNnTS3XouKF7dmrrBGIZbfFEzW6KXE,3543
|
|
252
252
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
253
253
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=iuo76cqmoHpF9eAc3Sqz4W-i6nnY1eeySBOdzh8bY5g,4287
|
|
254
|
-
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=
|
|
254
|
+
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=2Ex9kTJ4H9dQzV_5KHXCIk_tdgs3DFv8OwNALHS8PK8,8764
|
|
255
255
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py,sha256=n7VTA-a9TrLFpfdYAqrAKj6PGlAyLq8-xdwnMMpX71k,2077
|
|
256
256
|
model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
257
257
|
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=gvX5ILs5vjQ_F_dq5KaFs0GOQEq9gYXO5a6YZlYY8h4,3449
|
|
258
258
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
259
259
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=SJ5fetbUMkmB0tkHkmVhMrLksh7eqMQJLFuMD08ZKWM,3921
|
|
260
|
-
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=
|
|
260
|
+
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=2ur9W4t_lMhHVXMs1b538470IU5KfQ5mGeD6AUbYd8s,8656
|
|
261
261
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py,sha256=hinP-wtyxZyoW860GdJAk6M3iPjmwwPXQTUxd56yhq8,2086
|
|
262
262
|
model_compression_toolkit/gptq/__init__.py,sha256=2xos6AJziEy-eK91XtIJlunf8LhK4OayU7d6CQvXWsw,1276
|
|
263
263
|
model_compression_toolkit/gptq/runner.py,sha256=vWd7cWKgTGc9oPcTtwTQZoI3MArCx19Y61uteLFCxVo,5534
|
|
@@ -276,7 +276,7 @@ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=UKnP0iZlLjKVe
|
|
|
276
276
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
|
277
277
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=RWmsUXCw051shsPZ6igkSJBzqp7r4ddW1zYzZd3g0Xs,4751
|
|
278
278
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
|
279
|
-
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=
|
|
279
|
+
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=lUJ2cT5WCSoEeOD6ZTk0MxqFD5KlBtcgNktIxpMfZpI,4431
|
|
280
280
|
model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=iKzHnxl2ZSEp09oatfJVoiDuu6Q_iN36mOxQzDr1cy8,2087
|
|
281
281
|
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
282
282
|
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=t9-CQZE9AgnQ_Lq4SPd5uemvNcbtUHnU0qTHnx-QxZc,3962
|
|
@@ -293,7 +293,7 @@ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=T4512JClbZI
|
|
|
293
293
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
|
294
294
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=Zb-P0yRyZHHBlDvUBdRwxDpdduEJyJp6OT9pfKFF5ks,4171
|
|
295
295
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
|
296
|
-
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=
|
|
296
|
+
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=OAYcc34a6ozOIH4Ju4Bb1ggtSQ0MaBA-7sEpBkuui2I,4291
|
|
297
297
|
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=9owTzSu_xz29dsjONB-AYXuCZoPo_4nqxTk3yH18a0g,2089
|
|
298
298
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
|
299
299
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=oO7WgsAHMnWoXNm_gTKAAe-Nd79mGL_m677ai-ui424,4132
|
|
@@ -418,8 +418,8 @@ model_compression_toolkit/trainable_infrastructure/keras/load_model.py,sha256=Vq
|
|
|
418
418
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
419
419
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
420
420
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
|
|
421
|
-
mct_nightly-1.9.0.
|
|
422
|
-
mct_nightly-1.9.0.
|
|
423
|
-
mct_nightly-1.9.0.
|
|
424
|
-
mct_nightly-1.9.0.
|
|
425
|
-
mct_nightly-1.9.0.
|
|
421
|
+
mct_nightly-1.9.0.20230623.post423.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
422
|
+
mct_nightly-1.9.0.20230623.post423.dist-info/METADATA,sha256=7ao3yBjgYi_woxmPzx1rJKDIhTrncSAmOsZ0Gwj5mO0,10872
|
|
423
|
+
mct_nightly-1.9.0.20230623.post423.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
|
424
|
+
mct_nightly-1.9.0.20230623.post423.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
425
|
+
mct_nightly-1.9.0.20230623.post423.dist-info/RECORD,,
|
|
@@ -16,6 +16,8 @@ from typing import Dict, Any
|
|
|
16
16
|
|
|
17
17
|
from model_compression_toolkit.core.common import BaseNode
|
|
18
18
|
from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
|
|
19
|
+
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
|
|
20
|
+
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
|
|
19
21
|
|
|
20
22
|
from model_compression_toolkit.logger import Logger
|
|
21
23
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
@@ -24,54 +26,59 @@ from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
|
|
|
24
26
|
from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
|
|
25
27
|
from mct_quantizers import constants as qi_keras_consts
|
|
26
28
|
|
|
27
|
-
|
|
29
|
+
|
|
30
|
+
def get_inferable_quantizer_kwargs(node_qc: BaseNodeQuantizationConfig,
|
|
28
31
|
quantization_target: QuantizationTarget) -> Dict[str, Any]:
|
|
29
32
|
"""
|
|
30
33
|
Get the quantization parameters for an inferable quantizer.
|
|
31
34
|
Args:
|
|
32
|
-
|
|
35
|
+
node_qc: The node quantization configuration of the node for which the quantizer is being created.
|
|
36
|
+
Needs to match the specific quantization target.
|
|
33
37
|
quantization_target: The target of the quantization (weights or activations).
|
|
38
|
+
|
|
34
39
|
Returns:
|
|
35
40
|
The quantization parameters as a dictionary.
|
|
36
41
|
"""
|
|
37
42
|
|
|
38
43
|
if quantization_target == QuantizationTarget.Weights:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
44
|
+
if not isinstance(node_qc, NodeWeightsQuantizationConfig):
|
|
45
|
+
Logger.error(f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover
|
|
46
|
+
|
|
47
|
+
quantization_method = node_qc.weights_quantization_method
|
|
42
48
|
|
|
43
49
|
# Return the appropriate quantization parameters based on the quantization method
|
|
44
50
|
if quantization_method in [QuantizationMethod.POWER_OF_TWO,
|
|
45
51
|
QuantizationMethod.SYMMETRIC]:
|
|
46
|
-
return {qi_keras_consts.NUM_BITS:
|
|
47
|
-
qi_keras_consts.THRESHOLD: list(
|
|
48
|
-
qi_keras_consts.PER_CHANNEL:
|
|
49
|
-
qi_keras_consts.CHANNEL_AXIS:
|
|
50
|
-
qi_keras_consts.INPUT_RANK: len(
|
|
52
|
+
return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits,
|
|
53
|
+
qi_keras_consts.THRESHOLD: list(node_qc.weights_quantization_params[THRESHOLD].flatten()),
|
|
54
|
+
qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold,
|
|
55
|
+
qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis,
|
|
56
|
+
qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[THRESHOLD].shape)}
|
|
51
57
|
|
|
52
58
|
elif quantization_method in [QuantizationMethod.UNIFORM]:
|
|
53
|
-
return {qi_keras_consts.NUM_BITS:
|
|
54
|
-
qi_keras_consts.PER_CHANNEL:
|
|
55
|
-
qi_keras_consts.MIN_RANGE: list(
|
|
56
|
-
qi_keras_consts.MAX_RANGE: list(
|
|
57
|
-
qi_keras_consts.CHANNEL_AXIS:
|
|
58
|
-
qi_keras_consts.INPUT_RANK: len(
|
|
59
|
+
return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits,
|
|
60
|
+
qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold,
|
|
61
|
+
qi_keras_consts.MIN_RANGE: list(node_qc.weights_quantization_params[RANGE_MIN].flatten()),
|
|
62
|
+
qi_keras_consts.MAX_RANGE: list(node_qc.weights_quantization_params[RANGE_MAX].flatten()),
|
|
63
|
+
qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis,
|
|
64
|
+
qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[RANGE_MIN].shape)}
|
|
59
65
|
|
|
60
66
|
elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
|
|
61
|
-
return {qi_keras_consts.NUM_BITS:
|
|
62
|
-
qi_keras_consts.PER_CHANNEL:
|
|
63
|
-
qi_keras_consts.CLUSTER_CENTERS:
|
|
64
|
-
qi_keras_consts.THRESHOLD: list(
|
|
65
|
-
qi_keras_consts.CHANNEL_AXIS:
|
|
67
|
+
return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits,
|
|
68
|
+
qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold,
|
|
69
|
+
qi_keras_consts.CLUSTER_CENTERS: node_qc.weights_quantization_params[CLUSTER_CENTERS],
|
|
70
|
+
qi_keras_consts.THRESHOLD: list(node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
|
|
71
|
+
qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis,
|
|
66
72
|
# TODO: how to pass multiplier nbits and eps for a specific node?
|
|
67
|
-
qi_keras_consts.INPUT_RANK: len(
|
|
73
|
+
qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
|
|
68
74
|
|
|
69
75
|
else:
|
|
70
76
|
Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover
|
|
71
77
|
|
|
72
78
|
elif quantization_target == QuantizationTarget.Activation:
|
|
73
|
-
|
|
74
|
-
|
|
79
|
+
if not isinstance(node_qc, NodeActivationQuantizationConfig):
|
|
80
|
+
Logger.error(f"Non-compatible node quantization config was given for quantization target Activation.") # pragma: no cover
|
|
81
|
+
|
|
75
82
|
quantization_method = node_qc.activation_quantization_method
|
|
76
83
|
|
|
77
84
|
# Return the appropriate quantization parameters based on the quantization method
|
|
@@ -118,7 +125,7 @@ def get_weights_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuantize
|
|
|
118
125
|
quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
|
|
119
126
|
weights_quantization_method,
|
|
120
127
|
BaseKerasInferableQuantizer)
|
|
121
|
-
kwargs = get_inferable_quantizer_kwargs(
|
|
128
|
+
kwargs = get_inferable_quantizer_kwargs(node_w_qc, QuantizationTarget.Weights)
|
|
122
129
|
|
|
123
130
|
return quantier_for_node(**kwargs)
|
|
124
131
|
|
|
@@ -140,6 +147,6 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuan
|
|
|
140
147
|
quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
|
|
141
148
|
activation_quantization_method,
|
|
142
149
|
BaseKerasInferableQuantizer)
|
|
143
|
-
kwargs = get_inferable_quantizer_kwargs(
|
|
150
|
+
kwargs = get_inferable_quantizer_kwargs(node_act_qc, QuantizationTarget.Activation)
|
|
144
151
|
|
|
145
152
|
return quantier_for_node(**kwargs)
|
|
@@ -18,6 +18,8 @@ from typing import Dict, Any
|
|
|
18
18
|
from model_compression_toolkit.core.common import BaseNode
|
|
19
19
|
from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
|
|
20
20
|
SCALE_PER_CHANNEL, CLUSTER_CENTERS
|
|
21
|
+
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
|
|
22
|
+
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
|
|
21
23
|
from model_compression_toolkit.logger import Logger
|
|
22
24
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
25
|
from mct_quantizers import QuantizationTarget
|
|
@@ -28,41 +30,66 @@ from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
|
|
|
28
30
|
import numpy as np
|
|
29
31
|
|
|
30
32
|
|
|
31
|
-
def get_weights_inferable_quantizer_kwargs(
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
def get_weights_inferable_quantizer_kwargs(node_qc: NodeWeightsQuantizationConfig) -> Dict[str, Any]:
|
|
34
|
+
"""
|
|
35
|
+
Get the quantization parameters for a weights inferable quantizer.
|
|
36
|
+
Args:
|
|
37
|
+
node_qc: The node quantization configuration of the node for which the quantizer is being created.
|
|
38
|
+
Needs to match the specific quantization target.
|
|
39
|
+
|
|
40
|
+
Returns:
|
|
41
|
+
The quantization parameters as a dictionary.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
if not isinstance(node_qc, NodeWeightsQuantizationConfig):
|
|
45
|
+
Logger.error(
|
|
46
|
+
f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover
|
|
47
|
+
|
|
48
|
+
quantization_method = node_qc.weights_quantization_method
|
|
35
49
|
|
|
36
50
|
# Return the appropriate quantization parameters based on the quantization method
|
|
37
51
|
if quantization_method in [QuantizationMethod.POWER_OF_TWO,
|
|
38
52
|
QuantizationMethod.SYMMETRIC]:
|
|
39
|
-
return {qi_inferable_quantizers_constants.NUM_BITS:
|
|
40
|
-
qi_inferable_quantizers_constants.THRESHOLD:
|
|
41
|
-
qi_inferable_quantizers_constants.PER_CHANNEL:
|
|
42
|
-
qi_inferable_quantizers_constants.CHANNEL_AXIS:
|
|
53
|
+
return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits,
|
|
54
|
+
qi_inferable_quantizers_constants.THRESHOLD: node_qc.weights_quantization_params[THRESHOLD].flatten(),
|
|
55
|
+
qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold,
|
|
56
|
+
qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis}
|
|
43
57
|
|
|
44
58
|
elif quantization_method in [QuantizationMethod.UNIFORM]:
|
|
45
|
-
return {qi_inferable_quantizers_constants.NUM_BITS:
|
|
46
|
-
qi_inferable_quantizers_constants.PER_CHANNEL:
|
|
47
|
-
qi_inferable_quantizers_constants.MIN_RANGE:
|
|
48
|
-
qi_inferable_quantizers_constants.MAX_RANGE:
|
|
49
|
-
qi_inferable_quantizers_constants.CHANNEL_AXIS:
|
|
59
|
+
return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits,
|
|
60
|
+
qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold,
|
|
61
|
+
qi_inferable_quantizers_constants.MIN_RANGE: node_qc.weights_quantization_params[RANGE_MIN].flatten(),
|
|
62
|
+
qi_inferable_quantizers_constants.MAX_RANGE: node_qc.weights_quantization_params[RANGE_MAX].flatten(),
|
|
63
|
+
qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis}
|
|
50
64
|
|
|
51
65
|
elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
|
|
52
|
-
return {qi_inferable_quantizers_constants.NUM_BITS:
|
|
53
|
-
qi_inferable_quantizers_constants.CLUSTER_CENTERS:
|
|
54
|
-
qi_inferable_quantizers_constants.THRESHOLD:
|
|
55
|
-
qi_inferable_quantizers_constants.PER_CHANNEL:
|
|
56
|
-
qi_inferable_quantizers_constants.CHANNEL_AXIS:
|
|
66
|
+
return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits,
|
|
67
|
+
qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
|
|
68
|
+
qi_inferable_quantizers_constants.THRESHOLD: node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
|
|
69
|
+
qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold,
|
|
70
|
+
qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis}
|
|
57
71
|
# TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
|
|
58
72
|
|
|
59
73
|
else:
|
|
60
74
|
Logger.critical(f'Not supported quantization method for weights inferable quantizers.') # pragma: no cover
|
|
61
75
|
|
|
62
76
|
|
|
63
|
-
def get_activation_inferable_quantizer_kwargs(
|
|
64
|
-
|
|
65
|
-
|
|
77
|
+
def get_activation_inferable_quantizer_kwargs(node_qc: NodeActivationQuantizationConfig) -> Dict[str, Any]:
|
|
78
|
+
"""
|
|
79
|
+
Get the quantization parameters for an activation inferable quantizer.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
node_qc: The node quantization configuration of the node for which the quantizer is being created.
|
|
83
|
+
Needs to match the specific quantization target.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
The quantization parameters as a dictionary.
|
|
87
|
+
"""
|
|
88
|
+
|
|
89
|
+
if not isinstance(node_qc, NodeActivationQuantizationConfig):
|
|
90
|
+
Logger.error(
|
|
91
|
+
f"Non-compatible node quantization config was given for quantization target Activation.") # pragma: no cover
|
|
92
|
+
|
|
66
93
|
quantization_method = node_qc.activation_quantization_method
|
|
67
94
|
|
|
68
95
|
# Return the appropriate quantization parameters based on the quantization method
|
|
@@ -109,7 +136,7 @@ def get_weights_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuanti
|
|
|
109
136
|
quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
|
|
110
137
|
weights_quantization_method,
|
|
111
138
|
BasePyTorchInferableQuantizer)
|
|
112
|
-
kwargs = get_weights_inferable_quantizer_kwargs(
|
|
139
|
+
kwargs = get_weights_inferable_quantizer_kwargs(node_w_qc)
|
|
113
140
|
|
|
114
141
|
return quantier_for_node(**kwargs)
|
|
115
142
|
|
|
@@ -134,7 +161,7 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQu
|
|
|
134
161
|
quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
|
|
135
162
|
activation_quantization_method,
|
|
136
163
|
BasePyTorchInferableQuantizer)
|
|
137
|
-
kwargs = get_activation_inferable_quantizer_kwargs(
|
|
164
|
+
kwargs = get_activation_inferable_quantizer_kwargs(node_act_qc)
|
|
138
165
|
|
|
139
166
|
return quantizer_for_node(**kwargs)
|
|
140
167
|
|
|
@@ -24,6 +24,8 @@ from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer im
|
|
|
24
24
|
from mct_quantizers import QuantizationTarget
|
|
25
25
|
from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
|
|
26
26
|
from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
|
|
27
|
+
|
|
28
|
+
from model_compression_toolkit.logger import Logger
|
|
27
29
|
from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
|
|
28
30
|
get_trainable_quantizer_weights_config
|
|
29
31
|
from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
|
|
@@ -63,13 +65,17 @@ def quantization_builder(n: common.BaseNode,
|
|
|
63
65
|
|
|
64
66
|
activation_quantizers = []
|
|
65
67
|
if n.is_activation_quantization_enabled():
|
|
68
|
+
if n.final_activation_quantization_cfg is None:
|
|
69
|
+
Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration') #
|
|
70
|
+
# pragma: no cover
|
|
71
|
+
|
|
66
72
|
quant_method = n.final_activation_quantization_cfg.activation_quantization_method
|
|
67
73
|
|
|
68
74
|
quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
|
|
69
75
|
quant_method=quant_method,
|
|
70
76
|
quantizer_base_class=BaseKerasInferableQuantizer)
|
|
71
77
|
|
|
72
|
-
kwargs = get_inferable_quantizer_kwargs(n, QuantizationTarget.Activation)
|
|
78
|
+
kwargs = get_inferable_quantizer_kwargs(n.final_activation_quantization_cfg, QuantizationTarget.Activation)
|
|
73
79
|
|
|
74
80
|
activation_quantizers.append(quantizer_class(**kwargs))
|
|
75
81
|
|
|
@@ -24,6 +24,8 @@ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantize
|
|
|
24
24
|
from mct_quantizers import QuantizationTarget
|
|
25
25
|
from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
|
|
26
26
|
from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
|
|
27
|
+
|
|
28
|
+
from model_compression_toolkit.logger import Logger
|
|
27
29
|
from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
|
|
28
30
|
get_trainable_quantizer_weights_config
|
|
29
31
|
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
|
|
@@ -60,13 +62,17 @@ def quantization_builder(n: common.BaseNode,
|
|
|
60
62
|
**gptq_config.gptq_quantizer_params_override)})
|
|
61
63
|
activation_quantizers = []
|
|
62
64
|
if n.is_activation_quantization_enabled():
|
|
65
|
+
if n.final_activation_quantization_cfg is None:
|
|
66
|
+
Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration') #
|
|
67
|
+
# pragma: no cover
|
|
68
|
+
|
|
63
69
|
quant_method = n.final_activation_quantization_cfg.activation_quantization_method
|
|
64
70
|
|
|
65
71
|
quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
|
|
66
72
|
quant_method=quant_method,
|
|
67
73
|
quantizer_base_class=BasePyTorchInferableQuantizer)
|
|
68
74
|
|
|
69
|
-
kwargs = get_activation_inferable_quantizer_kwargs(n)
|
|
75
|
+
kwargs = get_activation_inferable_quantizer_kwargs(n.final_activation_quantization_cfg)
|
|
70
76
|
|
|
71
77
|
activation_quantizers.append(quantizer_class(**kwargs))
|
|
72
78
|
|
|
File without changes
|
{mct_nightly-1.9.0.20230621.post405.dist-info → mct_nightly-1.9.0.20230623.post423.dist-info}/WHEEL
RENAMED
|
File without changes
|
|
File without changes
|