mct-nightly 1.9.0.20230621.post405__py3-none-any.whl → 1.9.0.20230623.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 1.9.0.20230621.post405
3
+ Version: 1.9.0.20230623.post423
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -251,13 +251,13 @@ model_compression_toolkit/exporter/model_wrapper/keras/__init__.py,sha256=cco4Tm
251
251
  model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py,sha256=ihcMbqi_UGYnDZNnTS3XouKF7dmrrBGIZbfFEzW6KXE,3543
252
252
  model_compression_toolkit/exporter/model_wrapper/keras/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
253
253
  model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=iuo76cqmoHpF9eAc3Sqz4W-i6nnY1eeySBOdzh8bY5g,4287
254
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=prOosEwrTEUsg4gvnZwgyLtDu2id-eMsZ97pEHHBGwM,8318
254
+ model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=2Ex9kTJ4H9dQzV_5KHXCIk_tdgs3DFv8OwNALHS8PK8,8764
255
255
  model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py,sha256=n7VTA-a9TrLFpfdYAqrAKj6PGlAyLq8-xdwnMMpX71k,2077
256
256
  model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
257
257
  model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=gvX5ILs5vjQ_F_dq5KaFs0GOQEq9gYXO5a6YZlYY8h4,3449
258
258
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
259
259
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=SJ5fetbUMkmB0tkHkmVhMrLksh7eqMQJLFuMD08ZKWM,3921
260
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=gNURwKHO5C3fez_SPZ9lxfp7FamN5A6W6Jp4AaGQJBE,7582
260
+ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=2ur9W4t_lMhHVXMs1b538470IU5KfQ5mGeD6AUbYd8s,8656
261
261
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py,sha256=hinP-wtyxZyoW860GdJAk6M3iPjmwwPXQTUxd56yhq8,2086
262
262
  model_compression_toolkit/gptq/__init__.py,sha256=2xos6AJziEy-eK91XtIJlunf8LhK4OayU7d6CQvXWsw,1276
263
263
  model_compression_toolkit/gptq/runner.py,sha256=vWd7cWKgTGc9oPcTtwTQZoI3MArCx19Y61uteLFCxVo,5534
@@ -276,7 +276,7 @@ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=UKnP0iZlLjKVe
276
276
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
277
277
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=RWmsUXCw051shsPZ6igkSJBzqp7r4ddW1zYzZd3g0Xs,4751
278
278
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
279
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=fh5CNTs0S47txLn8pWJfnif4CJEb1PsQbYFGBWhOp1Q,4136
279
+ model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=lUJ2cT5WCSoEeOD6ZTk0MxqFD5KlBtcgNktIxpMfZpI,4431
280
280
  model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=iKzHnxl2ZSEp09oatfJVoiDuu6Q_iN36mOxQzDr1cy8,2087
281
281
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
282
282
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=t9-CQZE9AgnQ_Lq4SPd5uemvNcbtUHnU0qTHnx-QxZc,3962
@@ -293,7 +293,7 @@ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=T4512JClbZI
293
293
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
294
294
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=Zb-P0yRyZHHBlDvUBdRwxDpdduEJyJp6OT9pfKFF5ks,4171
295
295
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
296
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=PHbfJf7qdqWMmTGxxdGGoGFsQhhSqTELa6Sv3jeS9sQ,3996
296
+ model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=OAYcc34a6ozOIH4Ju4Bb1ggtSQ0MaBA-7sEpBkuui2I,4291
297
297
  model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=9owTzSu_xz29dsjONB-AYXuCZoPo_4nqxTk3yH18a0g,2089
298
298
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
299
299
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=oO7WgsAHMnWoXNm_gTKAAe-Nd79mGL_m677ai-ui424,4132
@@ -418,8 +418,8 @@ model_compression_toolkit/trainable_infrastructure/keras/load_model.py,sha256=Vq
418
418
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
419
419
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
420
420
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
421
- mct_nightly-1.9.0.20230621.post405.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
422
- mct_nightly-1.9.0.20230621.post405.dist-info/METADATA,sha256=EwUkuJv0yr8rUVK4EJjLWewvKE3-BEGo60uHdOz7MCo,10872
423
- mct_nightly-1.9.0.20230621.post405.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
424
- mct_nightly-1.9.0.20230621.post405.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
425
- mct_nightly-1.9.0.20230621.post405.dist-info/RECORD,,
421
+ mct_nightly-1.9.0.20230623.post423.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
422
+ mct_nightly-1.9.0.20230623.post423.dist-info/METADATA,sha256=7ao3yBjgYi_woxmPzx1rJKDIhTrncSAmOsZ0Gwj5mO0,10872
423
+ mct_nightly-1.9.0.20230623.post423.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
424
+ mct_nightly-1.9.0.20230623.post423.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
425
+ mct_nightly-1.9.0.20230623.post423.dist-info/RECORD,,
@@ -16,6 +16,8 @@ from typing import Dict, Any
16
16
 
17
17
  from model_compression_toolkit.core.common import BaseNode
18
18
  from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
19
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
20
+ NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
19
21
 
20
22
  from model_compression_toolkit.logger import Logger
21
23
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
@@ -24,54 +26,59 @@ from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
24
26
  from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
25
27
  from mct_quantizers import constants as qi_keras_consts
26
28
 
27
- def get_inferable_quantizer_kwargs(node: BaseNode,
29
+
30
+ def get_inferable_quantizer_kwargs(node_qc: BaseNodeQuantizationConfig,
28
31
  quantization_target: QuantizationTarget) -> Dict[str, Any]:
29
32
  """
30
33
  Get the quantization parameters for an inferable quantizer.
31
34
  Args:
32
- node: The node for which the quantizer is being created.
35
+ node_qc: The node quantization configuration of the node for which the quantizer is being created.
36
+ Needs to match the specific quantization target.
33
37
  quantization_target: The target of the quantization (weights or activations).
38
+
34
39
  Returns:
35
40
  The quantization parameters as a dictionary.
36
41
  """
37
42
 
38
43
  if quantization_target == QuantizationTarget.Weights:
39
- # Get the weights quantization configuration for the node
40
- node_w_qc = node.final_weights_quantization_cfg
41
- quantization_method = node_w_qc.weights_quantization_method
44
+ if not isinstance(node_qc, NodeWeightsQuantizationConfig):
45
+ Logger.error(f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover
46
+
47
+ quantization_method = node_qc.weights_quantization_method
42
48
 
43
49
  # Return the appropriate quantization parameters based on the quantization method
44
50
  if quantization_method in [QuantizationMethod.POWER_OF_TWO,
45
51
  QuantizationMethod.SYMMETRIC]:
46
- return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
47
- qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
48
- qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
49
- qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
50
- qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
52
+ return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits,
53
+ qi_keras_consts.THRESHOLD: list(node_qc.weights_quantization_params[THRESHOLD].flatten()),
54
+ qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold,
55
+ qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis,
56
+ qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[THRESHOLD].shape)}
51
57
 
52
58
  elif quantization_method in [QuantizationMethod.UNIFORM]:
53
- return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
54
- qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
55
- qi_keras_consts.MIN_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
56
- qi_keras_consts.MAX_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
57
- qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
58
- qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
59
+ return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits,
60
+ qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold,
61
+ qi_keras_consts.MIN_RANGE: list(node_qc.weights_quantization_params[RANGE_MIN].flatten()),
62
+ qi_keras_consts.MAX_RANGE: list(node_qc.weights_quantization_params[RANGE_MAX].flatten()),
63
+ qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis,
64
+ qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[RANGE_MIN].shape)}
59
65
 
60
66
  elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
61
- return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
62
- qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
63
- qi_keras_consts.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS],
64
- qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
65
- qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
67
+ return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits,
68
+ qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold,
69
+ qi_keras_consts.CLUSTER_CENTERS: node_qc.weights_quantization_params[CLUSTER_CENTERS],
70
+ qi_keras_consts.THRESHOLD: list(node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
71
+ qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis,
66
72
  # TODO: how to pass multiplier nbits and eps for a specific node?
67
- qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
73
+ qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
68
74
 
69
75
  else:
70
76
  Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover
71
77
 
72
78
  elif quantization_target == QuantizationTarget.Activation:
73
- # Get the activation quantization configuration for the node
74
- node_qc = node.final_activation_quantization_cfg
79
+ if not isinstance(node_qc, NodeActivationQuantizationConfig):
80
+ Logger.error(f"Non-compatible node quantization config was given for quantization target Activation.") # pragma: no cover
81
+
75
82
  quantization_method = node_qc.activation_quantization_method
76
83
 
77
84
  # Return the appropriate quantization parameters based on the quantization method
@@ -118,7 +125,7 @@ def get_weights_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuantize
118
125
  quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
119
126
  weights_quantization_method,
120
127
  BaseKerasInferableQuantizer)
121
- kwargs = get_inferable_quantizer_kwargs(node, QuantizationTarget.Weights)
128
+ kwargs = get_inferable_quantizer_kwargs(node_w_qc, QuantizationTarget.Weights)
122
129
 
123
130
  return quantier_for_node(**kwargs)
124
131
 
@@ -140,6 +147,6 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuan
140
147
  quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
141
148
  activation_quantization_method,
142
149
  BaseKerasInferableQuantizer)
143
- kwargs = get_inferable_quantizer_kwargs(node, QuantizationTarget.Activation)
150
+ kwargs = get_inferable_quantizer_kwargs(node_act_qc, QuantizationTarget.Activation)
144
151
 
145
152
  return quantier_for_node(**kwargs)
@@ -18,6 +18,8 @@ from typing import Dict, Any
18
18
  from model_compression_toolkit.core.common import BaseNode
19
19
  from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
20
20
  SCALE_PER_CHANNEL, CLUSTER_CENTERS
21
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
22
+ NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
21
23
  from model_compression_toolkit.logger import Logger
22
24
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
25
  from mct_quantizers import QuantizationTarget
@@ -28,41 +30,66 @@ from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
28
30
  import numpy as np
29
31
 
30
32
 
31
- def get_weights_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
32
- # Get the weights quantization configuration for the node
33
- node_w_qc = node.final_weights_quantization_cfg
34
- quantization_method = node_w_qc.weights_quantization_method
33
+ def get_weights_inferable_quantizer_kwargs(node_qc: NodeWeightsQuantizationConfig) -> Dict[str, Any]:
34
+ """
35
+ Get the quantization parameters for a weights inferable quantizer.
36
+ Args:
37
+ node_qc: The node quantization configuration of the node for which the quantizer is being created.
38
+ Needs to match the specific quantization target.
39
+
40
+ Returns:
41
+ The quantization parameters as a dictionary.
42
+ """
43
+
44
+ if not isinstance(node_qc, NodeWeightsQuantizationConfig):
45
+ Logger.error(
46
+ f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover
47
+
48
+ quantization_method = node_qc.weights_quantization_method
35
49
 
36
50
  # Return the appropriate quantization parameters based on the quantization method
37
51
  if quantization_method in [QuantizationMethod.POWER_OF_TWO,
38
52
  QuantizationMethod.SYMMETRIC]:
39
- return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
40
- qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[THRESHOLD].flatten(),
41
- qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
42
- qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
53
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits,
54
+ qi_inferable_quantizers_constants.THRESHOLD: node_qc.weights_quantization_params[THRESHOLD].flatten(),
55
+ qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold,
56
+ qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis}
43
57
 
44
58
  elif quantization_method in [QuantizationMethod.UNIFORM]:
45
- return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
46
- qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
47
- qi_inferable_quantizers_constants.MIN_RANGE: node_w_qc.weights_quantization_params[RANGE_MIN].flatten(),
48
- qi_inferable_quantizers_constants.MAX_RANGE: node_w_qc.weights_quantization_params[RANGE_MAX].flatten(),
49
- qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
59
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits,
60
+ qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold,
61
+ qi_inferable_quantizers_constants.MIN_RANGE: node_qc.weights_quantization_params[RANGE_MIN].flatten(),
62
+ qi_inferable_quantizers_constants.MAX_RANGE: node_qc.weights_quantization_params[RANGE_MAX].flatten(),
63
+ qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis}
50
64
 
51
65
  elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
52
- return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
53
- qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
54
- qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
55
- qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
56
- qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
66
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits,
67
+ qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
68
+ qi_inferable_quantizers_constants.THRESHOLD: node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
69
+ qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold,
70
+ qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis}
57
71
  # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
58
72
 
59
73
  else:
60
74
  Logger.critical(f'Not supported quantization method for weights inferable quantizers.') # pragma: no cover
61
75
 
62
76
 
63
- def get_activation_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
64
- # Get the activation quantization configuration for the node
65
- node_qc = node.final_activation_quantization_cfg
77
+ def get_activation_inferable_quantizer_kwargs(node_qc: NodeActivationQuantizationConfig) -> Dict[str, Any]:
78
+ """
79
+ Get the quantization parameters for an activation inferable quantizer.
80
+
81
+ Args:
82
+ node_qc: The node quantization configuration of the node for which the quantizer is being created.
83
+ Needs to match the specific quantization target.
84
+
85
+ Returns:
86
+ The quantization parameters as a dictionary.
87
+ """
88
+
89
+ if not isinstance(node_qc, NodeActivationQuantizationConfig):
90
+ Logger.error(
91
+ f"Non-compatible node quantization config was given for quantization target Activation.") # pragma: no cover
92
+
66
93
  quantization_method = node_qc.activation_quantization_method
67
94
 
68
95
  # Return the appropriate quantization parameters based on the quantization method
@@ -109,7 +136,7 @@ def get_weights_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuanti
109
136
  quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
110
137
  weights_quantization_method,
111
138
  BasePyTorchInferableQuantizer)
112
- kwargs = get_weights_inferable_quantizer_kwargs(node)
139
+ kwargs = get_weights_inferable_quantizer_kwargs(node_w_qc)
113
140
 
114
141
  return quantier_for_node(**kwargs)
115
142
 
@@ -134,7 +161,7 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQu
134
161
  quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
135
162
  activation_quantization_method,
136
163
  BasePyTorchInferableQuantizer)
137
- kwargs = get_activation_inferable_quantizer_kwargs(node)
164
+ kwargs = get_activation_inferable_quantizer_kwargs(node_act_qc)
138
165
 
139
166
  return quantizer_for_node(**kwargs)
140
167
 
@@ -24,6 +24,8 @@ from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer im
24
24
  from mct_quantizers import QuantizationTarget
25
25
  from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
26
26
  from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
27
+
28
+ from model_compression_toolkit.logger import Logger
27
29
  from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
28
30
  get_trainable_quantizer_weights_config
29
31
  from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
@@ -63,13 +65,17 @@ def quantization_builder(n: common.BaseNode,
63
65
 
64
66
  activation_quantizers = []
65
67
  if n.is_activation_quantization_enabled():
68
+ if n.final_activation_quantization_cfg is None:
69
+ Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration') #
70
+ # pragma: no cover
71
+
66
72
  quant_method = n.final_activation_quantization_cfg.activation_quantization_method
67
73
 
68
74
  quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
69
75
  quant_method=quant_method,
70
76
  quantizer_base_class=BaseKerasInferableQuantizer)
71
77
 
72
- kwargs = get_inferable_quantizer_kwargs(n, QuantizationTarget.Activation)
78
+ kwargs = get_inferable_quantizer_kwargs(n.final_activation_quantization_cfg, QuantizationTarget.Activation)
73
79
 
74
80
  activation_quantizers.append(quantizer_class(**kwargs))
75
81
 
@@ -24,6 +24,8 @@ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantize
24
24
  from mct_quantizers import QuantizationTarget
25
25
  from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
26
26
  from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
27
+
28
+ from model_compression_toolkit.logger import Logger
27
29
  from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
28
30
  get_trainable_quantizer_weights_config
29
31
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
@@ -60,13 +62,17 @@ def quantization_builder(n: common.BaseNode,
60
62
  **gptq_config.gptq_quantizer_params_override)})
61
63
  activation_quantizers = []
62
64
  if n.is_activation_quantization_enabled():
65
+ if n.final_activation_quantization_cfg is None:
66
+ Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration') #
67
+ # pragma: no cover
68
+
63
69
  quant_method = n.final_activation_quantization_cfg.activation_quantization_method
64
70
 
65
71
  quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
66
72
  quant_method=quant_method,
67
73
  quantizer_base_class=BasePyTorchInferableQuantizer)
68
74
 
69
- kwargs = get_activation_inferable_quantizer_kwargs(n)
75
+ kwargs = get_activation_inferable_quantizer_kwargs(n.final_activation_quantization_cfg)
70
76
 
71
77
  activation_quantizers.append(quantizer_class(**kwargs))
72
78