mct-nightly 1.11.0.20240214.post405__py3-none-any.whl → 1.11.0.20240215.post404__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240214.post405.dist-info → mct_nightly-1.11.0.20240215.post404.dist-info}/METADATA +1 -1
- {mct_nightly-1.11.0.20240214.post405.dist-info → mct_nightly-1.11.0.20240215.post404.dist-info}/RECORD +8 -8
- model_compression_toolkit/constants.py +0 -12
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +52 -17
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +1 -1
- {mct_nightly-1.11.0.20240214.post405.dist-info → mct_nightly-1.11.0.20240215.post404.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240214.post405.dist-info → mct_nightly-1.11.0.20240215.post404.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240214.post405.dist-info → mct_nightly-1.11.0.20240215.post404.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
model_compression_toolkit/__init__.py,sha256=WXRBerevhP4sQ4NIHd-tcHcMEan7Qx_Wz1sTtF-HuQc,3697
|
|
2
|
-
model_compression_toolkit/constants.py,sha256=
|
|
2
|
+
model_compression_toolkit/constants.py,sha256=_OW_bUeQmf08Bb4oVZ0KfUt-rcCeNOmdBv3aP7NF5fM,3631
|
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
4
4
|
model_compression_toolkit/logger.py,sha256=b9DVktZ-LymFcRxv2aL_sdiE6S2sSrFGWltx6dgEuUY,4863
|
|
5
5
|
model_compression_toolkit/core/__init__.py,sha256=pRP8FZ_46vpd6MVrcec5O5wnoByQqRzq_tMzaDRiMmM,1934
|
|
@@ -99,7 +99,7 @@ model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB
|
|
|
99
99
|
model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=hkm8xU4o9LvFeCc_KRg7PGYd_eQa6Kbjx-rGHvgajnA,5054
|
|
100
100
|
model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=gmzD32xsfJH8vkkqaspS7vYa6VWayk1GJe-NfoAEugQ,5901
|
|
101
101
|
model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
|
102
|
-
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=
|
|
102
|
+
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=VW0vGCPKMxbhl2cB_zHx3g9c7cNS4ctVEAvnaNq17jw,5153
|
|
103
103
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=8DRM4Ar4Er-bllo56LG-Lcx9U2Ebd3jJctf4t2hOcXc,2021
|
|
104
104
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
|
|
105
105
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=4GCr4Z6pRMbxIAnq4s7YtdMSqwbRwUzTzCFfs2ahVfk,6137
|
|
@@ -110,7 +110,7 @@ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,
|
|
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=sEPDeClFxh0uHEGznX7E3bSOJ_t0kUvyWcdxcyZJdwA,4090
|
|
111
111
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=_OQEFAdYDTHu2Qp-qs02Z1CDxugUKG6k5eCePS1WpXY,2939
|
|
112
112
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=UK_YshvZI0-LrKeT9gFGYcMA7pma1kaR5JAfzJH3HNw,3614
|
|
113
|
-
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=
|
|
113
|
+
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=qkrWJXLyDSIJhvT8tO9Nh51f4abyVR8zMFuaaMRRrRw,12304
|
|
114
114
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=_U4IFPuzGyyAymjDjsPl2NF6UbFggqBaiA1Td3sug3I,1608
|
|
115
115
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=rwCedE0zggamSBY50rqh-xqZpIMrn8o96YH_jMCuPrk,16505
|
|
116
116
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py,sha256=qDfJbvY64KLOG6n18ddEPTFGrKHlaXzZ136TrVpgH9s,2917
|
|
@@ -475,8 +475,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
475
475
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
476
476
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
477
477
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
|
|
478
|
-
mct_nightly-1.11.0.
|
|
479
|
-
mct_nightly-1.11.0.
|
|
480
|
-
mct_nightly-1.11.0.
|
|
481
|
-
mct_nightly-1.11.0.
|
|
482
|
-
mct_nightly-1.11.0.
|
|
478
|
+
mct_nightly-1.11.0.20240215.post404.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
479
|
+
mct_nightly-1.11.0.20240215.post404.dist-info/METADATA,sha256=4EU04O01WY1IopkRRWpPxNvWs6FTE-1etC_M8XzDLaM,17187
|
|
480
|
+
mct_nightly-1.11.0.20240215.post404.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
481
|
+
mct_nightly-1.11.0.20240215.post404.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
482
|
+
mct_nightly-1.11.0.20240215.post404.dist-info/RECORD,,
|
|
@@ -104,18 +104,6 @@ VIRTUAL_WEIGHTS_SUFFIX = '_v_weights'
|
|
|
104
104
|
VIRTUAL_ACTIVATION_SUFFIX = '_v_activation'
|
|
105
105
|
VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX = 'virtual'
|
|
106
106
|
|
|
107
|
-
# Quantization config candidate initialization
|
|
108
|
-
ACTIVATION_QUANTIZATION_CFG = 'activation_quantization_cfg'
|
|
109
|
-
WEIGHTS_QUANTIZATION_CFG = 'weights_quantization_cfg'
|
|
110
|
-
QC = 'qc'
|
|
111
|
-
OP_CFG = 'op_cfg'
|
|
112
|
-
ACTIVATION_QUANTIZATION_FN = 'activation_quantization_fn'
|
|
113
|
-
WEIGHTS_QUANTIZATION_FN = 'weights_quantization_fn'
|
|
114
|
-
ACTIVATION_QUANT_PARAMS_FN = 'activation_quantization_params_fn'
|
|
115
|
-
WEIGHTS_QUANT_PARAMS_FN = 'weights_quantization_params_fn'
|
|
116
|
-
WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
|
|
117
|
-
WEIGHTS_CFG = 'weights_cfg'
|
|
118
|
-
|
|
119
107
|
# Memory graph constants
|
|
120
108
|
DUMMY_NODE = 'dummy_node'
|
|
121
109
|
DUMMY_TENSOR = 'dummy_tensor'
|
|
@@ -12,12 +12,14 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from
|
|
16
|
-
|
|
17
|
-
|
|
15
|
+
from typing import Callable
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.core import QuantizationConfig
|
|
18
18
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
|
|
19
19
|
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
|
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
|
|
21
|
+
AttributeQuantizationConfig
|
|
22
|
+
from model_compression_toolkit.logger import Logger
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
##########################################
|
|
@@ -31,22 +33,55 @@ class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
31
33
|
Class for representing candidate node configuration, which includes weights and activation configuration combined.
|
|
32
34
|
"""
|
|
33
35
|
|
|
34
|
-
def __init__(self,
|
|
35
|
-
|
|
36
|
+
def __init__(self,
|
|
37
|
+
qc: QuantizationConfig = None,
|
|
38
|
+
op_cfg: OpQuantizationConfig = None,
|
|
39
|
+
activation_quantization_cfg: NodeActivationQuantizationConfig = None,
|
|
40
|
+
activation_quantization_fn: Callable = None,
|
|
41
|
+
activation_quantization_params_fn: Callable = None,
|
|
42
|
+
weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
|
|
43
|
+
weights_quantization_fn: Callable = None,
|
|
44
|
+
weights_quantization_params_fn: Callable = None,
|
|
45
|
+
weights_channels_axis: int = None,
|
|
46
|
+
weights_cfg: AttributeQuantizationConfig = None):
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
qc: QuantizationConfig to create the node's config from.
|
|
51
|
+
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
|
|
52
|
+
activation_quantization_cfg: An option to pass a NodeActivationQuantizationConfig to create a new config from.
|
|
53
|
+
activation_quantization_fn: Function to use when quantizing the node's activations.
|
|
54
|
+
activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
|
|
55
|
+
weights_quantization_cfg: An option to pass a NodeWeightsQuantizationConfig to create a new config from.
|
|
56
|
+
weights_quantization_fn: Function to use when quantizing the node's weights.
|
|
57
|
+
weights_quantization_params_fn: Function to use when computing the threshold for quantizing a node's weights.
|
|
58
|
+
weights_channels_axis: Axis to quantize a node's kernel when quantizing per-channel.
|
|
59
|
+
weights_cfg: Weights attribute quantization config.
|
|
60
|
+
"""
|
|
61
|
+
|
|
36
62
|
if activation_quantization_cfg is not None:
|
|
37
63
|
self.activation_quantization_cfg = activation_quantization_cfg
|
|
38
64
|
else:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
65
|
+
if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)):
|
|
66
|
+
Logger.error("Missing some required arguments to initialize "
|
|
67
|
+
"a node activation quantization configuration.")
|
|
68
|
+
self.activation_quantization_cfg = (
|
|
69
|
+
NodeActivationQuantizationConfig(qc=qc,
|
|
70
|
+
op_cfg=op_cfg,
|
|
71
|
+
activation_quantization_fn=activation_quantization_fn,
|
|
72
|
+
activation_quantization_params_fn=activation_quantization_params_fn))
|
|
73
|
+
|
|
44
74
|
if weights_quantization_cfg is not None:
|
|
45
75
|
self.weights_quantization_cfg = weights_quantization_cfg
|
|
46
76
|
else:
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
77
|
+
if any(v is None for v in (qc, op_cfg, weights_quantization_fn, weights_quantization_params_fn,
|
|
78
|
+
weights_cfg)):
|
|
79
|
+
Logger.error("Missing some required arguments to initialize "
|
|
80
|
+
"a node weights quantization configuration.")
|
|
81
|
+
self.weights_quantization_cfg = (
|
|
82
|
+
NodeWeightsQuantizationConfig(qc=qc,
|
|
83
|
+
op_cfg=op_cfg,
|
|
84
|
+
weights_quantization_fn=weights_quantization_fn,
|
|
85
|
+
weights_quantization_params_fn=weights_quantization_params_fn,
|
|
86
|
+
weights_channels_axis=weights_channels_axis,
|
|
87
|
+
weights_cfg=weights_cfg))
|
|
@@ -169,7 +169,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
|
|
|
169
169
|
activation_quantization_params_fn=activation_quantization_params_fn,
|
|
170
170
|
weights_quantization_fn=weights_quantization_fn,
|
|
171
171
|
weights_quantization_params_fn=weights_quantization_params_fn,
|
|
172
|
-
|
|
172
|
+
weights_channels_axis=weight_channel_axis,
|
|
173
173
|
weights_cfg=weights_cfg)
|
|
174
174
|
|
|
175
175
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|