mct-nightly 1.8.0.2032023.post428-py3-none-any.whl → 1.8.0.2042023.post413-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {mct_nightly-1.8.0.2032023.post428.dist-info → mct_nightly-1.8.0.2042023.post413.dist-info}/METADATA +7 -7
  2. {mct_nightly-1.8.0.2032023.post428.dist-info → mct_nightly-1.8.0.2042023.post413.dist-info}/RECORD +65 -59
  3. {mct_nightly-1.8.0.2032023.post428.dist-info → mct_nightly-1.8.0.2042023.post413.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +9 -15
  5. model_compression_toolkit/core/common/logger.py +10 -2
  6. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  7. model_compression_toolkit/core/keras/back2framework/model_gradients.py +3 -2
  8. model_compression_toolkit/core/keras/quantization_facade.py +1 -1
  9. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +13 -6
  10. model_compression_toolkit/core/pytorch/constants.py +4 -0
  11. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  12. model_compression_toolkit/exporter/__init__.py +5 -0
  13. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  14. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
  15. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  16. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -39
  17. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +39 -24
  18. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +50 -42
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -36
  20. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +24 -5
  21. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  22. model_compression_toolkit/gptq/__init__.py +6 -0
  23. model_compression_toolkit/gptq/common/gptq_config.py +58 -106
  24. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  25. model_compression_toolkit/gptq/common/gptq_training.py +28 -38
  26. model_compression_toolkit/gptq/keras/gptq_training.py +10 -28
  27. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  28. model_compression_toolkit/gptq/keras/quantization_facade.py +6 -12
  29. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +0 -1
  30. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  31. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  32. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  33. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +38 -139
  34. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +11 -41
  35. model_compression_toolkit/gptq/pytorch/gptq_training.py +12 -4
  36. model_compression_toolkit/gptq/pytorch/graph_info.py +9 -6
  37. model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -22
  38. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +3 -1
  39. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +0 -20
  40. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -1
  41. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  42. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  43. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py +14 -0
  44. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  45. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  46. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  47. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +9 -31
  48. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +30 -37
  49. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +27 -36
  50. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +21 -21
  51. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +25 -26
  52. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  53. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
  54. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  55. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  56. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +13 -3
  57. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  58. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  59. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +53 -2
  60. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +2 -1
  61. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +22 -4
  62. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +24 -3
  63. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  64. {mct_nightly-1.8.0.2032023.post428.dist-info → mct_nightly-1.8.0.2042023.post413.dist-info}/LICENSE.md +0 -0
  65. {mct_nightly-1.8.0.2032023.post428.dist-info → mct_nightly-1.8.0.2042023.post413.dist-info}/top_level.txt +0 -0
  66. /model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +0 -0

model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
@@ -14,46 +14,52 @@
 # ==============================================================================
 from typing import Tuple
 
-import tensorflow as tf
-from tensorflow.keras.layers import Layer
 
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common import Graph, Logger
+from model_compression_toolkit.core.common.constants import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
-from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import \
-    get_quantization_quantizers
-
-
-def _get_wrapper(node: common.BaseNode,
-                 layer: Layer) -> qi.KerasQuantizationWrapper:
-    """
-    A function which takes a computational graph node and a keras layer and perform the quantization wrapping
-    Args:
-        n: A node of mct graph.
-        layer: A keras layer
-
-    Returns: Wrapped layer with weights quantizers and activation quantizers
-
-    """
-    weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
-    return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
-
-
-def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model,UserInformation]:
-    """
-    Convert graph to an exportable Keras model (model with all quantization parameters).
-    An exportable model can then be exported using model_exporter, to retrieve the
-    final exported model.
-
-    Args:
-        graph: Graph to convert to an exportable Keras model.
-
-    Returns:
-        Exportable Keras model and user information.
-    """
-    exportable_model, user_info = KerasModelBuilder(graph=graph,
-                                                    wrapper=_get_wrapper).build_model()
-    exportable_model.trainable = False
-    return exportable_model, user_info
+
+if FOUND_TF:
+    import tensorflow as tf
+    from tensorflow.keras.layers import Layer
+    from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
+
+    def _get_wrapper(node: common.BaseNode,
+                     layer: Layer) -> qi.KerasQuantizationWrapper:
+        """
+        A function which takes a computational graph node and a keras layer and perform the quantization wrapping
+        Args:
+            n: A node of mct graph.
+            layer: A keras layer
+
+        Returns: Wrapped layer with weights quantizers and activation quantizers
+
+        """
+        weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
+
+
+    def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
+        """
+        Convert graph to an exportable Keras model (model with all quantization parameters).
+        An exportable model can then be exported using model_exporter, to retrieve the
+        final exported model.
+
+        Args:
+            graph: Graph to convert to an exportable Keras model.
+
+        Returns:
+            Exportable Keras model and user information.
+        """
+        exportable_model, user_info = KerasModelBuilder(graph=graph,
+                                                        wrapper=_get_wrapper).build_model()
+        exportable_model.trainable = False
+        return exportable_model, user_info
+else:
+    def get_exportable_keras_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using get_exportable_keras_model. '
+                     'Could not find Tensorflow package.')
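
The change above sets the pattern repeated throughout this release: framework imports are deferred behind FOUND_TF / FOUND_TORCH availability flags, so importing the exporter package no longer crashes when only one framework is installed, and a missing framework fails loudly at call time instead of import time. A minimal sketch of the pattern, assuming an importlib probe (MCT computes the real flags in model_compression_toolkit.core.common.constants):

    import importlib.util

    # Assumed probe; the real FOUND_TF lives in MCT's constants module.
    FOUND_TF = importlib.util.find_spec('tensorflow') is not None

    if FOUND_TF:
        import tensorflow as tf  # safe: only executed when TF is installed

        def get_exportable_keras_model(graph):
            ...  # real implementation may reference tf freely
    else:
        def get_exportable_keras_model(*args, **kwargs):
            # Same public name, stub body: the error surfaces when called.
            raise ImportError('Could not find Tensorflow package.')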

model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py
@@ -15,15 +15,12 @@
 from typing import Dict, Any
 
 from model_compression_toolkit.core.common import BaseNode, Logger
-from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED
+from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
 from model_compression_toolkit.core.common.target_platform import QuantizationMethod
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
-    get_inferable_quantizer_class
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer \
-    import \
-    BaseKerasInferableQuantizer
-
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts
 
 def get_inferable_quantizer_kwargs(node: BaseNode,
                                    quantization_target: QuantizationTarget) -> Dict[str, Any]:
@@ -44,19 +41,29 @@ def get_inferable_quantizer_kwargs(node: BaseNode,
         # Return the appropriate quantization parameters based on the quantization method
         if quantization_method in [QuantizationMethod.POWER_OF_TWO,
                                    QuantizationMethod.SYMMETRIC]:
-            return {'num_bits': node_w_qc.weights_n_bits,
-                    'threshold': list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
-                    'per_channel': node_w_qc.weights_per_channel_threshold,
-                    'channel_axis': node_w_qc.weights_channels_axis,
-                    'input_rank': len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
 
         elif quantization_method in [QuantizationMethod.UNIFORM]:
-            return {'num_bits': node_w_qc.weights_n_bits,
-                    'per_channel': node_w_qc.weights_per_channel_threshold,
-                    'min_range': list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
-                    'max_range': list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
-                    'channel_axis': node_w_qc.weights_channels_axis,
-                    'input_rank': len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.MIN_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
+                    qi_keras_consts.MAX_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
+
+        elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
+
         else:
            Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
 
@@ -68,16 +75,24 @@ def get_inferable_quantizer_kwargs(node: BaseNode,
         # Return the appropriate quantization parameters based on the quantization method
         if quantization_method in [QuantizationMethod.POWER_OF_TWO,
                                    QuantizationMethod.SYMMETRIC]:
-            return {'num_bits': node_qc.activation_n_bits,
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
                     # In activation quantization is per-tensor only - thus we hold the threshold as a list with a len of 1
-                    'threshold': [node_qc.activation_quantization_params[THRESHOLD]],
-                    'signed': node_qc.activation_quantization_params[SIGNED]}
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]],
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED]}
 
         elif quantization_method in [QuantizationMethod.UNIFORM]:
-            return {'num_bits': node_qc.activation_n_bits,
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
                     # In activation quantization is per-tensor only - thus we hold the min/max as a list with a len of 1
-                    'min_range': [node_qc.activation_quantization_params[RANGE_MIN]],
-                    'max_range': [node_qc.activation_quantization_params[RANGE_MAX]]}
+                    qi_keras_consts.MIN_RANGE: [node_qc.activation_quantization_params[RANGE_MIN]],
+                    qi_keras_consts.MAX_RANGE: [node_qc.activation_quantization_params[RANGE_MAX]]}
+
+        elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED],
+                    qi_keras_consts.CLUSTER_CENTERS: node_qc.activation_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]]
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    }
        else:
            Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
    else:
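
These kwargs dictionaries are now keyed by constants from the inferable quantizers' constants module, so the keys stay in lockstep with the quantizer constructors' parameter names. A hedged sketch of how they are consumed, using only calls visible in this diff (the weights-side pairing below is an assumption mirrored from the activation-side code further down):

    # node: an MCT graph node with a final weights quantization config (assumed).
    quantizer_class = get_inferable_quantizer_class(QuantizationTarget.Weights,
                                                    quantization_method,
                                                    BaseKerasInferableQuantizer)
    kwargs = get_inferable_quantizer_kwargs(node, QuantizationTarget.Weights)
    quantizer = quantizer_class(**kwargs)  # dict keys must match __init__ parameter names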

model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py
@@ -14,61 +14,69 @@
 # ==============================================================================
 from typing import Any
 
-from keras.engine.input_layer import InputLayer
 
 from model_compression_toolkit.core.common import Logger
-from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
-
+from model_compression_toolkit.core.common.constants import FOUND_TF
 
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
 
-def is_keras_layer_exportable(layer: Any) -> bool:
-    """
-    Check whether a Keras layer is a valid exportable layer or not.
 
-    Args:
-        layer: Keras layer to check if considered to be valid for exporting.
+if FOUND_TF:
+    from keras.engine.input_layer import InputLayer
+    from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
 
-    Returns:
+    def is_keras_layer_exportable(layer: Any) -> bool:
+        """
         Check whether a Keras layer is a valid exportable layer or not.
-    """
-    # Keras Input layers are not wrapped
-    if isinstance(layer, InputLayer):
-        return True
 
-    valid_layer = isinstance(layer, KerasQuantizationWrapper)
-    if not valid_layer:
-        Logger.error(
-            f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
-            f'{type(layer)}')  # pragma: no cover
+        Args:
+            layer: Keras layer to check if considered to be valid for exporting.
 
-    valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
-    if not valid_weights_quantizers:
-        Logger.error(
-            f'KerasQuantizationWrapper must have a weights_quantizers but has a '
-            f'{type(layer.weights_quantizers)} object')  # pragma: no cover
+        Returns:
+            Check whether a Keras layer is a valid exportable layer or not.
+        """
+        # Keras Input layers are not wrapped
+        if isinstance(layer, InputLayer):
+            return True
 
-    for _, weights_quantizer in layer.weights_quantizers.items():
-        if not isinstance(weights_quantizer, BaseInferableQuantizer):
+        valid_layer = isinstance(layer, KerasQuantizationWrapper)
+        if not valid_layer:
            Logger.error(
-                f'weights_quantizer must be a BaseInferableQuantizer object but has a '
-                f'{type(weights_quantizer)} object')  # pragma: no cover
+                f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
+                f'{type(layer)}')  # pragma: no cover
 
-    valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
-    if not valid_activation_quantizers:
-        Logger.error(
-            f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
-            f'{type(layer.activation_quantizers)} object')  # pragma: no cover
+        valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+        if not valid_weights_quantizers:
+            Logger.error(
+                f'KerasQuantizationWrapper must have a weights_quantizers but has a '
+                f'{type(layer.weights_quantizers)} object')  # pragma: no cover
+
+        for _, weights_quantizer in layer.weights_quantizers.items():
+            if not isinstance(weights_quantizer, BaseInferableQuantizer):
+                Logger.error(
+                    f'weights_quantizer must be a BaseInferableQuantizer object but has a '
+                    f'{type(weights_quantizer)} object')  # pragma: no cover
 
-    for activation_quantizers in layer.activation_quantizers:
-        if not isinstance(activation_quantizers, BaseInferableQuantizer):
+        valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
+        if not valid_activation_quantizers:
            Logger.error(
-                f'activation_quantizers must be a BaseInferableQuantizer object but has a '
-                f'{type(activation_quantizers)} object')  # pragma: no cover
+                f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
+                f'{type(layer.activation_quantizers)} object')  # pragma: no cover
 
-    quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
-    is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
-    if not is_valid_quantizers:
-        Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer')  # pragma: no cover
+        for activation_quantizers in layer.activation_quantizers:
+            if not isinstance(activation_quantizers, BaseInferableQuantizer):
+                Logger.error(
+                    f'activation_quantizers must be a BaseInferableQuantizer object but has a '
+                    f'{type(activation_quantizers)} object')  # pragma: no cover
 
-    return True
+        quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
+        is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
+        if not is_valid_quantizers:
+            Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer')  # pragma: no cover
+
+        return True
+else:
+    def is_keras_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using is_keras_layer_exportable. '
+                     'Could not find Tensorflow package.')
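
The validator is meant to run over every layer of a model produced by get_exportable_keras_model, reporting each failed check through Logger.error. A hedged usage sketch (exportable_model is assumed to come from the builder earlier in this diff):

    # Validate all layers before handing the model to an exporter.
    for layer in exportable_model.layers:
        is_keras_layer_exportable(layer)  # logs an error for any non-exportable layer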

model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
@@ -13,42 +13,49 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
 
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
-from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
-    get_quantization_quantizers
-
-
-def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
-    """
-    A function which takes a computational graph node and a pytorch module and
-    perform the quantization wrapping
-
-    Args:
-        node: A node of mct graph.
-        module: A Pytorch module
-
-    Returns: Wrapped layer
-
-    """
-    weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
-    wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
-    return wrapped_layer
-
-
-def get_exportable_pytorch_model(graph: Graph):
-    """
-    Convert graph to fully quantized PyTorch model.
-
-    Args:
-        graph: Graph to convert to a PyTorch model.
-
-    Returns:
-        Fully quantized PyTorch model.
-    """
-    return PyTorchModelBuilder(graph=graph,
-                               wrapper=fully_quantized_wrapper).build_model()
+from model_compression_toolkit.core.common import Graph, Logger
+from model_compression_toolkit.core.common.constants import FOUND_TORCH
+
+if FOUND_TORCH:
+    import torch
+    from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
+        get_quantization_quantizers
+
+    def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
+        """
+        A function which takes a computational graph node and a pytorch module and
+        perform the quantization wrapping
+
+        Args:
+            node: A node of mct graph.
+            module: A Pytorch module
+
+        Returns: Wrapped layer
+
+        """
+        weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
+        return wrapped_layer
+
+
+    def get_exportable_pytorch_model(graph: Graph):
+        """
+        Convert graph to fully quantized PyTorch model.
+
+        Args:
+            graph: Graph to convert to a PyTorch model.
+
+        Returns:
+            Fully quantized PyTorch model.
+        """
+        return PyTorchModelBuilder(graph=graph,
+                                   wrapper=fully_quantized_wrapper).build_model()
+else:
+    def get_exportable_pytorch_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using get_exportable_pytorch_model. '
+                     'Could not find PyTorch package.')
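
The PyTorch path mirrors the Keras builder: each module is wrapped positionally with its weight and activation quantizers, and build_model() returns the wrapped model together with user info. A hedged sketch of the intended call sequence (graph is MCT's internal Graph, produced by its core pipeline rather than constructed by users):

    # Sketch only: mirrors the Keras flow above.
    exportable_model, user_info = get_exportable_pytorch_model(graph)
    exportable_model.eval()  # assumed freeze-for-export step, analogous to trainable=False in Keras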

model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py
@@ -16,7 +16,8 @@
 from typing import Dict, Any
 
 from model_compression_toolkit.core.common import BaseNode, Logger
-from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
+    SCALE_PER_CHANNEL, CLUSTER_CENTERS
 from model_compression_toolkit.core.common.target_platform import QuantizationMethod
 from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
@@ -45,6 +46,15 @@ def get_weights_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
                qi_inferable_quantizers_constants.MIN_RANGE: node_w_qc.weights_quantization_params[RANGE_MIN].flatten(),
                qi_inferable_quantizers_constants.MAX_RANGE: node_w_qc.weights_quantization_params[RANGE_MAX].flatten(),
                qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
+                qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
+                qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+        # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
+
    else:
        Logger.critical(f'Not supported quantization method for weights inferable quantizers.')  # pragma: no cover
 
@@ -65,6 +75,15 @@ def get_activation_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
        return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
                qi_inferable_quantizers_constants.MIN_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MIN]]),
                qi_inferable_quantizers_constants.MAX_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MAX]])}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: np.asarray(
+                    [node_qc.activation_quantization_params[CLUSTER_CENTERS]]),
+                qi_inferable_quantizers_constants.THRESHOLD: np.asarray(
+                    [node_qc.activation_quantization_params[THRESHOLD]]),
+                qi_inferable_quantizers_constants.SIGNED: node_qc.activation_quantization_params.get(SIGNED)}
+        # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
    else:
        Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
 
@@ -111,10 +130,10 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuantizer:
    node_act_qc = node.final_activation_quantization_cfg
    activation_quantization_method = node_act_qc.activation_quantization_method
 
-    quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
-                                                      activation_quantization_method,
-                                                      BasePyTorchInferableQuantizer)
+    quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
+                                                       activation_quantization_method,
+                                                       BasePyTorchInferableQuantizer)
    kwargs = get_activation_inferable_quantizer_kwargs(node)
 
-    return quantier_for_node(**kwargs)
+    return quantizer_for_node(**kwargs)
 
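
The new LUT branches hand the quantizer a set of cluster centers plus a per-channel scale (stored under SCALE_PER_CHANNEL but passed as THRESHOLD). Conceptually, a LUT quantizer snaps each value to its nearest scaled cluster center; a toy numpy illustration of that idea, not MCT's implementation:

    import numpy as np

    cluster_centers = np.array([-4.0, -1.0, 0.0, 2.0])  # hypothetical centers
    scale = 0.5                                         # per-channel scale (threshold)
    x = np.array([0.3, -1.7, 0.9])

    # Index of the nearest center for each rescaled value, then map back.
    idx = np.argmin(np.abs(x[:, None] / scale - cluster_centers[None, :]), axis=1)
    x_q = cluster_centers[idx] * scale                  # -> [0.0, -2.0, 1.0]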

model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py
@@ -14,24 +14,31 @@
 # ==============================================================================
 from typing import Any
 
-from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
-    BasePyTorchInferableQuantizer
+from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.core.common.constants import FOUND_TORCH
 
+if FOUND_TORCH:
+    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
+        BasePyTorchInferableQuantizer
+    def is_pytorch_layer_exportable(layer: Any) -> bool:
+        """
+        Check whether a torch Module is a valid exportable module or not.
 
-def is_pytorch_layer_exportable(layer: Any) -> bool:
-    """
-    Check whether a torch Module is a valid exportable module or not.
+        Args:
+            layer: PyTorch module to check if considered to be valid for exporting.
 
-    Args:
-        layer: PyTorch module to check if considered to be valid for exporting.
-
-    Returns:
-        Check whether a PyTorch layer is a valid exportable layer or not.
-    """
-    if isinstance(layer, PytorchQuantizationWrapper):
-        quantizers = list(layer.weights_quantizers.values())
-        quantizers.extend(layer.activation_quantizers)
-        if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
-            return True
-    return False
+        Returns:
+            Check whether a PyTorch layer is a valid exportable layer or not.
+        """
+        if isinstance(layer, PytorchQuantizationWrapper):
+            quantizers = list(layer.weights_quantizers.values())
+            quantizers.extend(layer.activation_quantizers)
+            if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
+                return True
+        return False
+else:
+    def is_pytorch_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using is_pytorch_layer_exportable. '
+                     'Could not find PyTorch package.')

model_compression_toolkit/gptq/__init__.py
@@ -12,3 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2
+from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
+from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
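
With these re-exports, the GPTQ facade becomes reachable as mct.gptq.*. A hedged end-to-end sketch for the Keras path; the exact signatures may differ between nightly builds, and float_model plus the input shape are assumptions:

    import numpy as np
    import model_compression_toolkit as mct

    def representative_data_gen():
        # Hypothetical calibration generator for an ImageNet-like Keras model.
        yield [np.random.randn(1, 224, 224, 3)]

    gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)  # n_epochs kwarg assumed
    quantized_model, quantization_info = \
        mct.gptq.keras_gradient_post_training_quantization_experimental(
            float_model,              # a trained tf.keras model (assumed to exist)
            representative_data_gen,
            gptq_config=gptq_config)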