mct-nightly 1.11.0.20240213.post434__py3-none-any.whl → 1.11.0.20240215.post404__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 1.11.0.20240213.post434
3
+ Version: 1.11.0.20240215.post404
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,5 +1,5 @@
1
1
  model_compression_toolkit/__init__.py,sha256=WXRBerevhP4sQ4NIHd-tcHcMEan7Qx_Wz1sTtF-HuQc,3697
2
- model_compression_toolkit/constants.py,sha256=5aystyH4YQv3J9X3Xx3eQvnfFBpo1NDju8jwfqH4z2A,4131
2
+ model_compression_toolkit/constants.py,sha256=_OW_bUeQmf08Bb4oVZ0KfUt-rcCeNOmdBv3aP7NF5fM,3631
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=b9DVktZ-LymFcRxv2aL_sdiE6S2sSrFGWltx6dgEuUY,4863
5
5
  model_compression_toolkit/core/__init__.py,sha256=pRP8FZ_46vpd6MVrcec5O5wnoByQqRzq_tMzaDRiMmM,1934
@@ -99,7 +99,7 @@ model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB
99
99
  model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=hkm8xU4o9LvFeCc_KRg7PGYd_eQa6Kbjx-rGHvgajnA,5054
100
100
  model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=gmzD32xsfJH8vkkqaspS7vYa6VWayk1GJe-NfoAEugQ,5901
101
101
  model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
102
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=2iu35iI5gnWLHBKSaLVsPQWr1ssly6Z-gbaNCauvcQM,3223
102
+ model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=VW0vGCPKMxbhl2cB_zHx3g9c7cNS4ctVEAvnaNq17jw,5153
103
103
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=8DRM4Ar4Er-bllo56LG-Lcx9U2Ebd3jJctf4t2hOcXc,2021
104
104
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
105
105
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=4GCr4Z6pRMbxIAnq4s7YtdMSqwbRwUzTzCFfs2ahVfk,6137
@@ -110,7 +110,7 @@ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,
110
110
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=sEPDeClFxh0uHEGznX7E3bSOJ_t0kUvyWcdxcyZJdwA,4090
111
111
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=_OQEFAdYDTHu2Qp-qs02Z1CDxugUKG6k5eCePS1WpXY,2939
112
112
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=UK_YshvZI0-LrKeT9gFGYcMA7pma1kaR5JAfzJH3HNw,3614
113
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=3Z5XiR6bESBrPPkvq9JetndiJ-R7hwXNQgMtxo_P2mc,12302
113
+ model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=qkrWJXLyDSIJhvT8tO9Nh51f4abyVR8zMFuaaMRRrRw,12304
114
114
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=_U4IFPuzGyyAymjDjsPl2NF6UbFggqBaiA1Td3sug3I,1608
115
115
  model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=rwCedE0zggamSBY50rqh-xqZpIMrn8o96YH_jMCuPrk,16505
116
116
  model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py,sha256=qDfJbvY64KLOG6n18ddEPTFGrKHlaXzZ136TrVpgH9s,2917
@@ -475,8 +475,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
475
475
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
476
476
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
477
477
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
478
- mct_nightly-1.11.0.20240213.post434.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
479
- mct_nightly-1.11.0.20240213.post434.dist-info/METADATA,sha256=9B2zsXw6Y1lq0vqBff_BZTxhxbQcU51CifFeZ5YlDLQ,17187
480
- mct_nightly-1.11.0.20240213.post434.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
481
- mct_nightly-1.11.0.20240213.post434.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
482
- mct_nightly-1.11.0.20240213.post434.dist-info/RECORD,,
478
+ mct_nightly-1.11.0.20240215.post404.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
479
+ mct_nightly-1.11.0.20240215.post404.dist-info/METADATA,sha256=4EU04O01WY1IopkRRWpPxNvWs6FTE-1etC_M8XzDLaM,17187
480
+ mct_nightly-1.11.0.20240215.post404.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
481
+ mct_nightly-1.11.0.20240215.post404.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
482
+ mct_nightly-1.11.0.20240215.post404.dist-info/RECORD,,
@@ -104,18 +104,6 @@ VIRTUAL_WEIGHTS_SUFFIX = '_v_weights'
104
104
  VIRTUAL_ACTIVATION_SUFFIX = '_v_activation'
105
105
  VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX = 'virtual'
106
106
 
107
- # Quantization config candidate initialization
108
- ACTIVATION_QUANTIZATION_CFG = 'activation_quantization_cfg'
109
- WEIGHTS_QUANTIZATION_CFG = 'weights_quantization_cfg'
110
- QC = 'qc'
111
- OP_CFG = 'op_cfg'
112
- ACTIVATION_QUANTIZATION_FN = 'activation_quantization_fn'
113
- WEIGHTS_QUANTIZATION_FN = 'weights_quantization_fn'
114
- ACTIVATION_QUANT_PARAMS_FN = 'activation_quantization_params_fn'
115
- WEIGHTS_QUANT_PARAMS_FN = 'weights_quantization_params_fn'
116
- WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
117
- WEIGHTS_CFG = 'weights_cfg'
118
-
119
107
  # Memory graph constants
120
108
  DUMMY_NODE = 'dummy_node'
121
109
  DUMMY_TENSOR = 'dummy_tensor'
@@ -12,12 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.constants import ACTIVATION_QUANTIZATION_CFG, WEIGHTS_QUANTIZATION_CFG, QC, \
16
- OP_CFG, ACTIVATION_QUANTIZATION_FN, WEIGHTS_QUANTIZATION_FN, ACTIVATION_QUANT_PARAMS_FN, WEIGHTS_QUANT_PARAMS_FN, \
17
- WEIGHTS_CHANNELS_AXIS, WEIGHTS_CFG
15
+ from typing import Callable
16
+
17
+ from model_compression_toolkit.core import QuantizationConfig
18
18
  from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
19
19
  NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
20
- from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
21
+ AttributeQuantizationConfig
22
+ from model_compression_toolkit.logger import Logger
21
23
 
22
24
 
23
25
  ##########################################
@@ -31,22 +33,55 @@ class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
31
33
  Class for representing candidate node configuration, which includes weights and activation configuration combined.
32
34
  """
33
35
 
34
- def __init__(self, **kwargs):
35
- activation_quantization_cfg = kwargs.get(ACTIVATION_QUANTIZATION_CFG, None)
36
+ def __init__(self,
37
+ qc: QuantizationConfig = None,
38
+ op_cfg: OpQuantizationConfig = None,
39
+ activation_quantization_cfg: NodeActivationQuantizationConfig = None,
40
+ activation_quantization_fn: Callable = None,
41
+ activation_quantization_params_fn: Callable = None,
42
+ weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
43
+ weights_quantization_fn: Callable = None,
44
+ weights_quantization_params_fn: Callable = None,
45
+ weights_channels_axis: int = None,
46
+ weights_cfg: AttributeQuantizationConfig = None):
47
+ """
48
+
49
+ Args:
50
+ qc: QuantizationConfig to create the node's config from.
51
+ op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
52
+ activation_quantization_cfg: An option to pass a NodeActivationQuantizationConfig to create a new config from.
53
+ activation_quantization_fn: Function to use when quantizing the node's activations.
54
+ activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
55
+ weights_quantization_cfg: An option to pass a NodeWeightsQuantizationConfig to create a new config from.
56
+ weights_quantization_fn: Function to use when quantizing the node's weights.
57
+ weights_quantization_params_fn: Function to use when computing the threshold for quantizing a node's weights.
58
+ weights_channels_axis: Axis to quantize a node's kernel when quantizing per-channel.
59
+ weights_cfg: Weights attribute quantization config.
60
+ """
61
+
36
62
  if activation_quantization_cfg is not None:
37
63
  self.activation_quantization_cfg = activation_quantization_cfg
38
64
  else:
39
- self.activation_quantization_cfg = NodeActivationQuantizationConfig(kwargs.get(QC),
40
- kwargs.get(OP_CFG),
41
- kwargs.get(ACTIVATION_QUANTIZATION_FN),
42
- kwargs.get(ACTIVATION_QUANT_PARAMS_FN))
43
- weights_quantization_cfg = kwargs.get(WEIGHTS_QUANTIZATION_CFG, None)
65
+ if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)):
66
+ Logger.error("Missing some required arguments to initialize "
67
+ "a node activation quantization configuration.")
68
+ self.activation_quantization_cfg = (
69
+ NodeActivationQuantizationConfig(qc=qc,
70
+ op_cfg=op_cfg,
71
+ activation_quantization_fn=activation_quantization_fn,
72
+ activation_quantization_params_fn=activation_quantization_params_fn))
73
+
44
74
  if weights_quantization_cfg is not None:
45
75
  self.weights_quantization_cfg = weights_quantization_cfg
46
76
  else:
47
- self.weights_quantization_cfg = NodeWeightsQuantizationConfig(kwargs.get(QC),
48
- kwargs.get(OP_CFG),
49
- kwargs.get(WEIGHTS_QUANTIZATION_FN),
50
- kwargs.get(WEIGHTS_QUANT_PARAMS_FN),
51
- kwargs.get(WEIGHTS_CHANNELS_AXIS),
52
- kwargs.get(WEIGHTS_CFG))
77
+ if any(v is None for v in (qc, op_cfg, weights_quantization_fn, weights_quantization_params_fn,
78
+ weights_cfg)):
79
+ Logger.error("Missing some required arguments to initialize "
80
+ "a node weights quantization configuration.")
81
+ self.weights_quantization_cfg = (
82
+ NodeWeightsQuantizationConfig(qc=qc,
83
+ op_cfg=op_cfg,
84
+ weights_quantization_fn=weights_quantization_fn,
85
+ weights_quantization_params_fn=weights_quantization_params_fn,
86
+ weights_channels_axis=weights_channels_axis,
87
+ weights_cfg=weights_cfg))
@@ -169,7 +169,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
169
169
  activation_quantization_params_fn=activation_quantization_params_fn,
170
170
  weights_quantization_fn=weights_quantization_fn,
171
171
  weights_quantization_params_fn=weights_quantization_params_fn,
172
- weight_channel_axis=weight_channel_axis,
172
+ weights_channels_axis=weight_channel_axis,
173
173
  weights_cfg=weights_cfg)
174
174
 
175
175