mct-nightly 1.8.0.20052023.post401-py3-none-any.whl → 1.8.0.20230610.post356-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/METADATA +10 -7
  2. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/RECORD +68 -115
  3. model_compression_toolkit/__init__.py +23 -3
  4. model_compression_toolkit/core/common/framework_info.py +1 -1
  5. model_compression_toolkit/core/keras/back2framework/instance_builder.py +16 -9
  6. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +8 -34
  7. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +5 -1
  8. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +103 -28
  9. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +39 -44
  10. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +1 -1
  11. model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +20 -18
  12. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +3 -3
  13. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  14. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +36 -9
  15. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -4
  16. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +24 -32
  17. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +31 -8
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +5 -5
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +34 -8
  20. model_compression_toolkit/gptq/keras/gptq_training.py +15 -16
  21. model_compression_toolkit/gptq/keras/graph_info.py +2 -2
  22. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
  23. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +5 -7
  24. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  25. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +6 -6
  26. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -7
  27. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +6 -6
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +30 -10
  29. model_compression_toolkit/gptq/pytorch/graph_info.py +5 -2
  30. model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
  31. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +4 -4
  32. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +5 -7
  33. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  34. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +7 -7
  35. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -8
  36. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +7 -8
  37. model_compression_toolkit/qat/common/__init__.py +2 -1
  38. model_compression_toolkit/qat/common/qat_config.py +2 -2
  39. model_compression_toolkit/qat/keras/quantization_facade.py +18 -8
  40. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +1 -1
  41. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +11 -11
  42. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +11 -12
  43. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +12 -13
  44. model_compression_toolkit/qat/pytorch/quantization_facade.py +27 -16
  45. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  46. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +31 -4
  47. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +10 -9
  48. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +11 -10
  49. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
  50. model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +1 -25
  51. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py → trainable_infrastructure/__init__.py} +3 -10
  52. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/base_trainable_quantizer.py +3 -3
  53. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizer_config.py +1 -1
  54. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizers.py +3 -3
  55. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/base_keras_quantizer.py +4 -4
  56. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/config_serialization.py +2 -2
  57. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure}/keras/load_model.py +16 -23
  58. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/pytorch/base_pytorch_quantizer.py +3 -3
  59. model_compression_toolkit/quantizers_infrastructure/__init__.py +0 -23
  60. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +0 -87
  61. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +0 -46
  62. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +0 -31
  63. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +0 -53
  64. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +0 -49
  65. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +0 -147
  66. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +0 -345
  67. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +0 -85
  68. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +0 -27
  69. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  70. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -148
  71. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -65
  72. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -86
  73. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -111
  74. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +0 -56
  75. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  76. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -79
  77. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -179
  78. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -67
  79. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -87
  80. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -163
  81. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +0 -66
  82. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +0 -14
  83. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +0 -269
  84. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +0 -152
  85. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +0 -35
  86. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  87. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -96
  88. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -62
  89. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -83
  90. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -100
  91. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +0 -95
  92. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +0 -48
  93. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +0 -70
  94. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +0 -57
  95. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +0 -26
  96. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  97. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -77
  98. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -106
  99. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -66
  100. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -104
  101. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -109
  102. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +0 -14
  103. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +0 -14
  104. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +0 -14
  105. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +0 -14
  106. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/LICENSE.md +0 -0
  107. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/WHEEL +0 -0
  108. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/top_level.txt +0 -0
  109. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure/common}/__init__.py +0 -0
  110. /model_compression_toolkit/{quantizers_infrastructure → trainable_infrastructure/common}/constants.py +0 -0
  111. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/quant_utils.py +0 -0
  112. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/trainable_quantizer_config.py +0 -0
  113. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/common → trainable_infrastructure/keras}/__init__.py +0 -0
  114. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/quantizer_utils.py +0 -0
  115. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras → trainable_infrastructure/pytorch}/__init__.py +0 -0
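
Taken together, the file list describes a single refactor: the internal model_compression_toolkit.quantizers_infrastructure package is deleted, its inferable quantizers and wrappers now come from the external mct_quantizers package, and the trainable part is promoted to a top-level model_compression_toolkit.trainable_infrastructure package. A minimal before/after import sketch assembled from the hunks below (illustrative, not an exhaustive mapping; it assumes both TensorFlow and PyTorch are installed):

# Before (1.8.0.20052023.post401): everything lived under quantizers_infrastructure, e.g.
# from model_compression_toolkit.quantizers_infrastructure import (
#     KerasQuantizationWrapper, PytorchQuantizationWrapper, QuantizationTarget)

# After (1.8.0.20230610.post356): inferable pieces come from the external
# mct_quantizers package, trainable pieces from trainable_infrastructure.
from mct_quantizers import (
    KerasQuantizationWrapper,            # wraps a Keras layer with weight quantizers
    KerasActivationQuantizationHolder,   # holds a single activation quantizer
    PytorchQuantizationWrapper,
    PytorchActivationQuantizationHolder,
    QuantizationTarget,
    mark_quantizer,
)
from model_compression_toolkit.trainable_infrastructure import (
    TrainableQuantizerWeightsConfig,
    TrainableQuantizerActivationConfig,
)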
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py
@@ -14,16 +14,14 @@
 # ==============================================================================
 from typing import Any
 
-
-from model_compression_toolkit.logger import Logger
+from mct_quantizers import BaseInferableQuantizer, KerasActivationQuantizationHolder
 from model_compression_toolkit.constants import FOUND_TF
-
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
-
+from model_compression_toolkit.logger import Logger
 
 if FOUND_TF:
+    from keras.engine.base_layer import Layer
     from keras.engine.input_layer import InputLayer
-    from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+    from mct_quantizers import KerasQuantizationWrapper
 
     def is_keras_layer_exportable(layer: Any) -> bool:
         """
@@ -39,40 +37,34 @@ if FOUND_TF:
         if isinstance(layer, InputLayer):
             return True
 
-        valid_layer = isinstance(layer, KerasQuantizationWrapper)
+        valid_layer = isinstance(layer, Layer)
         if not valid_layer:
             Logger.error(
-                f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
+                f'Exportable layer must be a Keras layer, but layer {layer.name} is of type '
                 f'{type(layer)}')  # pragma: no cover
 
-        valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
-        if not valid_weights_quantizers:
-            Logger.error(
-                f'KerasQuantizationWrapper must have a weights_quantizers but has a '
-                f'{type(layer.weights_quantizers)} object')  # pragma: no cover
-
-        for _, weights_quantizer in layer.weights_quantizers.items():
-            if not isinstance(weights_quantizer, BaseInferableQuantizer):
+        if isinstance(layer, KerasQuantizationWrapper):
+            valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+            if not valid_weights_quantizers:
                 Logger.error(
-                    f'weights_quantizer must be a BaseInferableQuantizer object but has a '
-                    f'{type(weights_quantizer)} object')  # pragma: no cover
+                    f'KerasQuantizationWrapper must have a weights_quantizers but has a '
+                    f'{type(layer.weights_quantizers)} object')  # pragma: no cover
 
-        valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
-        if not valid_activation_quantizers:
-            Logger.error(
-                f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
-                f'{type(layer.activation_quantizers)} object')  # pragma: no cover
+            if len(layer.weights_quantizers) == 0:
+                Logger.error(f'KerasQuantizationWrapper must have at least one weight quantizer, but found {len(layer.weights_quantizers)} quantizers. If layer is not quantized it should be a Keras layer.')
 
-        for activation_quantizers in layer.activation_quantizers:
-            if not isinstance(activation_quantizers, BaseInferableQuantizer):
-                Logger.error(
-                    f'activation_quantizers must be a BaseInferableQuantizer object but has a '
-                    f'{type(activation_quantizers)} object')  # pragma: no cover
+            for _, weights_quantizer in layer.weights_quantizers.items():
+                if not isinstance(weights_quantizer, BaseInferableQuantizer):
+                    Logger.error(
+                        f'weights_quantizer must be a BaseInferableQuantizer object but has a '
+                        f'{type(weights_quantizer)} object')  # pragma: no cover
 
-        quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
-        is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
-        if not is_valid_quantizers:
-            Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer')  # pragma: no cover
+        if isinstance(layer, KerasActivationQuantizationHolder):
+            if not isinstance(layer.activation_holder_quantizer, BaseInferableQuantizer):
+                Logger.error(
+                    f'activation quantizer in KerasActivationQuantizationHolder'
+                    f' must be a BaseInferableQuantizer object but has a '
+                    f'{type(layer.activation_holder_quantizer)} object')  # pragma: no cover
 
         return True
 else:
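
The net effect of the two hunks above: any plain Keras Layer now counts as exportable, and the quantizer checks apply only when the layer actually is a KerasQuantizationWrapper or a KerasActivationQuantizationHolder. A rough restatement of the new control flow, with the Logger.error calls collapsed into exceptions for brevity (a sketch, not the shipped function):

from typing import Any

from keras.engine.base_layer import Layer
from keras.engine.input_layer import InputLayer
from mct_quantizers import (BaseInferableQuantizer, KerasQuantizationWrapper,
                            KerasActivationQuantizationHolder)


def sketch_is_keras_layer_exportable(layer: Any) -> bool:
    # Input layers are always exportable.
    if isinstance(layer, InputLayer):
        return True
    # Any Keras layer is now acceptable, wrapped or not.
    if not isinstance(layer, Layer):
        raise TypeError(f'{type(layer)} is not a Keras layer')
    if isinstance(layer, KerasQuantizationWrapper):
        # Wrapped layers must carry a non-empty dict of inferable weight quantizers.
        if not (isinstance(layer.weights_quantizers, dict) and layer.weights_quantizers):
            raise ValueError('wrapper without weight quantizers')
        if not all(isinstance(q, BaseInferableQuantizer)
                   for q in layer.weights_quantizers.values()):
            raise TypeError('weight quantizers must be inferable quantizers')
    if isinstance(layer, KerasActivationQuantizationHolder):
        # A holder carries exactly one inferable activation quantizer.
        if not isinstance(layer.activation_holder_quantizer, BaseInferableQuantizer):
            raise TypeError('holder quantizer must be an inferable quantizer')
    return True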
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
@@ -13,20 +13,23 @@
 # limitations under the License.
 # ==============================================================================
 
-
-from model_compression_toolkit import quantizers_infrastructure as qi
+from typing import Union, Callable
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.constants import FOUND_TORCH
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core.common import BaseNode
 
 if FOUND_TORCH:
     import torch
+    from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
     from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
         get_quantization_quantizers
 
-    def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
+
+    def fully_quantized_wrapper(node: common.BaseNode,
+                                module: torch.nn.Module) -> Union[torch.nn.Module, PytorchQuantizationWrapper]:
         """
         A function which takes a computational graph node and a pytorch module and
         perform the quantization wrapping
@@ -34,14 +37,32 @@ if FOUND_TORCH:
         Args:
             node: A node of mct graph.
             module: A Pytorch module
-
         Returns: Wrapped layer
 
         """
-        weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
-        wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
-        return wrapped_layer
+        weight_quantizers, _ = get_quantization_quantizers(node)
+        if len(weight_quantizers) > 0:
+            return PytorchQuantizationWrapper(module, weight_quantizers)
+        return module
 
+    def get_activation_quantizer_holder(node: BaseNode) -> Callable:
+        """
+        Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
+        If the layer is not supposed to be wrapped with an activation quantizer - return None.
+        Args:
+            node: Node to attach a PytorchActivationQuantizationHolder to its output.
+        Returns:
+            A PytorchActivationQuantizationHolder module for the node's activation quantization.
+        """
+        _, activation_quantizers = get_quantization_quantizers(node)
+        # Holder by definition uses a single quantizer for the activation quantization
+        # thus we make sure this is the only possible case (unless it's a node with no activation
+        # quantization, which in this case has an empty list).
+        if len(activation_quantizers) == 1:
+            return PytorchActivationQuantizationHolder(activation_quantizers[0])
+        Logger.error(
+            f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+            f'were found for node {node}')
 
     def get_exportable_pytorch_model(graph: Graph):
         """
@@ -54,7 +75,9 @@
             Fully quantized PyTorch model.
         """
         return PyTorchModelBuilder(graph=graph,
-                                   wrapper=fully_quantized_wrapper).build_model()
+                                   wrapper=fully_quantized_wrapper,
+                                   get_activation_quantizer_holder_fn=get_activation_quantizer_holder).build_model()
+
 else:
     def get_exportable_pytorch_model(*args, **kwargs):  # pragma: no cover
         Logger.error('Installing torch is mandatory '
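
The PyTorch exporter now splits quantization into two attachment points: fully_quantized_wrapper wraps a module with its weight quantizers (or returns it untouched when there are none), while get_activation_quantizer_holder produces a separate holder module for the node's single activation quantizer. A sketch of the same pattern applied by hand to one module, assuming mct_quantizers' built-in POT inferable quantizers (the constructor arguments below are assumptions for illustration, not taken from this diff):

import torch
from mct_quantizers import (PytorchQuantizationWrapper,
                            PytorchActivationQuantizationHolder,
                            WeightsPOTInferableQuantizer,
                            ActivationPOTInferableQuantizer)

conv = torch.nn.Conv2d(3, 8, kernel_size=3)

# Weight quantization: the wrapper maps an attribute name to its quantizer.
wrapped_conv = PytorchQuantizationWrapper(
    conv, {'weight': WeightsPOTInferableQuantizer(num_bits=8,
                                                  threshold=[2.0],
                                                  per_channel=False)})

# Activation quantization: a separate holder module with exactly one quantizer,
# placed on the node's output by the model builder.
act_holder = PytorchActivationQuantizationHolder(
    ActivationPOTInferableQuantizer(num_bits=8, threshold=[4.0], signed=True))

y = act_holder(wrapped_conv(torch.randn(1, 3, 16, 16)))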
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py
@@ -20,11 +20,11 @@ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RA
     SCALE_PER_CHANNEL, CLUSTER_CENTERS
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
-from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
    get_inferable_quantizer_class
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
-    constants as qi_inferable_quantizers_constants, BasePyTorchInferableQuantizer
+from mct_quantizers import QuantizationTarget
+from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+from mct_quantizers import \
+    constants as qi_inferable_quantizers_constants
+from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
 import numpy as np
 
 
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py
@@ -17,10 +17,13 @@ from typing import Any
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TORCH
 
+
 if FOUND_TORCH:
-    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
-    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
-        BasePyTorchInferableQuantizer
+    import torch.nn as nn
+    from mct_quantizers import PytorchQuantizationWrapper
+    from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
+    from mct_quantizers.pytorch.activation_quantization_holder import PytorchActivationQuantizationHolder
+
     def is_pytorch_layer_exportable(layer: Any) -> bool:
         """
         Check whether a torch Module is a valid exportable module or not.
@@ -31,12 +34,35 @@
         Returns:
             Check whether a PyTorch layer is a valid exportable layer or not.
         """
+        if not isinstance(layer, nn.Module):
+            Logger.error(f'Exportable layer must be a nn.Module layer, but layer {layer.name} is of type {type(layer)}')  # pragma: no cover
+
         if isinstance(layer, PytorchQuantizationWrapper):
-            quantizers = list(layer.weights_quantizers.values())
-            quantizers.extend(layer.activation_quantizers)
-            if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
-                return True
-            return False
+            valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+            if not valid_weights_quantizers:
+                Logger.error(
+                    f'PytorchQuantizationWrapper must have a weights_quantizers but has a '
+                    f'{type(layer.weights_quantizers)} object')  # pragma: no cover
+
+            if len(layer.weights_quantizers) == 0:
+                Logger.error(f'PytorchQuantizationWrapper must have at least one weight quantizer, but found {len(layer.weights_quantizers)} quantizers.'
+                             f'If layer is not quantized it should be a Keras layer.')
+
+            for _, weights_quantizer in layer.weights_quantizers.items():
+                if not isinstance(weights_quantizer, BasePyTorchInferableQuantizer):
+                    Logger.error(
+                        f'weights_quantizer must be a BasePyTorchInferableQuantizer object but has a '
+                        f'{type(weights_quantizer)} object')  # pragma: no cover
+
+        elif isinstance(layer, PytorchActivationQuantizationHolder):
+            if not isinstance(layer.activation_holder_quantizer, BasePyTorchInferableQuantizer):
+                Logger.error(
+                    f'activation quantizer in PytorchActivationQuantizationHolder'
+                    f' must be a BasePyTorchInferableQuantizer object but has a '
+                    f'{type(layer.activation_holder_quantizer)} object')  # pragma: no cover
+
+        return True
+
 else:
     def is_pytorch_layer_exportable(*args, **kwargs):  # pragma: no cover
         Logger.error('Installing torch is mandatory '
model_compression_toolkit/gptq/keras/gptq_training.py
@@ -26,8 +26,7 @@ from model_compression_toolkit.core.keras.back2framework.keras_model_builder imp
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder
+from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder
 
 if version.parse(tf.__version__) < version.parse("2.6"):
     from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
@@ -45,7 +44,6 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
 import numpy as np
 import copy
 from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
-from model_compression_toolkit import quantizers_infrastructure as qi
 
 
 class KerasGPTQTrainer(GPTQTrainer):
@@ -133,7 +131,7 @@
 
     def gptq_wrapper(self,
                      n: common.BaseNode,
-                     layer: Layer) -> Union[qi.KerasQuantizationWrapper, Layer]:
+                     layer: Layer) -> Union[KerasQuantizationWrapper, Layer]:
         """
         A function which takes a computational graph node and a keras layer and perform the quantization wrapping.
 
@@ -145,22 +143,23 @@
 
         """
         if self._is_gptq_weights_trainable(n):
-            weights_quantizers, _ = quantization_builder(n, self.gptq_config)  # TODO: split quantizers building into two functions: for weights and activations
-            return qi.KerasQuantizationWrapper(layer,
-                                               weights_quantizers=weights_quantizers)
-        else:
-            return layer
-
-    def get_activation_quantizer_holder(self, n: common.BaseNode) -> Union[None, Callable]:
+            weights_quantizers, _ = quantization_builder(n,
+                                                         self.gptq_config)  # TODO: split quantizers building into two functions: for weights and activations
+            if len(weights_quantizers) > 0:
+                return KerasQuantizationWrapper(layer,
+                                                weights_quantizers=weights_quantizers)
+        return layer
+
+    def get_activation_quantizer_holder(self, n: common.BaseNode) -> Callable:
         """
-        Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
+        Retrieve a KerasActivationQuantizationHolder layer to use for activation quantization for a node.
         If the layer is not supposed to be wrapped with activation quantizers - return None.
 
         Args:
-            n: Node to get ActivationQuantizationHolder to attach in its output.
+            n: Node to get KerasActivationQuantizationHolder to attach in its output.
 
         Returns:
-            A ActivationQuantizationHolder layer for the node activation quantization.
+            A KerasActivationQuantizationHolder layer for the node activation quantization.
         """
         _, activation_quantizers = quantization_builder(n, self.gptq_config)  # TODO: split quantizers building into two functions: for weights and activations
 
@@ -168,10 +167,10 @@
         # thus we make sure this is the only possible case (unless it's a node with no activation
         # quantization, which in this case has an empty list).
         if len(activation_quantizers) == 1:
-            return ActivationQuantizationHolder(activation_quantizers[0])
+            return KerasActivationQuantizationHolder(activation_quantizers[0])
 
         Logger.error(
-            f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+            f'KerasActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
             f'were found for node {n}')
 
 
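Both the Keras trainer above and the PyTorch code elsewhere in this diff encode the same contract: a holder wraps exactly one activation quantizer, an empty list means "no activation quantization", and anything more is an error. A compact restatement of that contract (a hypothetical helper, not part of the package):

from typing import Optional
from mct_quantizers import KerasActivationQuantizationHolder


def holder_for(activation_quantizers: list) -> Optional[KerasActivationQuantizationHolder]:
    # Exactly one quantizer -> build a holder for the node's output.
    if len(activation_quantizers) == 1:
        return KerasActivationQuantizationHolder(activation_quantizers[0])
    # Empty list -> the node simply has no activation quantization.
    if len(activation_quantizers) == 0:
        return None
    # More than one quantizer violates the single-quantizer contract.
    raise ValueError(f'expected at most one activation quantizer, got {len(activation_quantizers)}')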
model_compression_toolkit/gptq/keras/graph_info.py
@@ -21,8 +21,8 @@ from tensorflow.keras.models import Model
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from mct_quantizers import KerasQuantizationWrapper
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def get_gptq_trainable_parameters(fxp_model: Model,
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py
@@ -19,15 +19,14 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
 
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
-    TrainableQuantizerActivationConfig
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
 
 if FOUND_TF:
     import tensorflow as tf
 
-    from model_compression_toolkit.quantizers_infrastructure import BaseKerasTrainableQuantizer, \
-        KerasQuantizationWrapper
+    from model_compression_toolkit.trainable_infrastructure import BaseKerasTrainableQuantizer
+    from mct_quantizers import KerasQuantizationWrapper
 
     class BaseKerasGPTQTrainableQuantizer(BaseKerasTrainableQuantizer):
         """
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py
@@ -21,14 +21,12 @@ from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quan
     get_inferable_quantizer_kwargs
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
-    get_inferable_quantizer_class
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import \
-    BaseKerasInferableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+from mct_quantizers import QuantizationTarget
+from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_weights_config
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class
 
 
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -19,7 +19,7 @@ from keras import Model
 
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+from mct_quantizers import KerasQuantizationWrapper
 
 
 class LinearTempDecay:
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py
@@ -17,9 +17,9 @@ import tensorflow as tf
 import numpy as np
 
 from model_compression_toolkit.gptq import RoundingType
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core.common import max_power_of_two
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget
 from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
@@ -27,11 +27,11 @@ from typing import Dict, Any
 from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
 from model_compression_toolkit.gptq.keras.quantizer.quant_utils import power_of_two_max, clip, calculate_delta
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
@@ -66,7 +66,7 @@
     return delta * clip(tensor_q, max_val=max_int, min_val=min_int)
 
 
-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=RoundingType.SoftQuantizer)
 class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py
@@ -17,20 +17,20 @@ import tensorflow as tf
 import numpy as np
 
 from model_compression_toolkit.gptq import RoundingType
-from model_compression_toolkit import quantizers_infrastructure as qi
-from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget
 from model_compression_toolkit.gptq.common.gptq_constants import \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
 from typing import Dict, Any
 from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def soft_rounding_uniform_quantizer(input_tensor: tf.Tensor,
@@ -61,7 +61,7 @@
                        max_val=2 ** num_bits - 1) + min_range
 
 
-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=RoundingType.SoftQuantizer)
 class UniformSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py
@@ -19,18 +19,18 @@ import numpy as np
 import tensorflow as tf
 
 from model_compression_toolkit.gptq import RoundingType
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget
 from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
 from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
 from model_compression_toolkit.constants import THRESHOLD
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
@@ -67,7 +67,7 @@
     return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)
 
 
-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=RoundingType.STE)
 class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
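
All three Keras GPTQ rounding quantizers keep the same @mark_quantizer decorator; only its import moves from quantizers_infrastructure to mct_quantizers. The decorator tags a trainable quantizer class with a (target, method, quantizer type) key so get_trainable_quantizer_class can look it up later. The skeleton of the pattern as it appears in these hunks (class body elided; MySTEWeightQuantizer is a placeholder name, not a class from this diff):

from mct_quantizers import QuantizationTarget, mark_quantizer
from model_compression_toolkit.gptq import RoundingType
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import \
    BaseKerasGPTQTrainableQuantizer


@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                quantizer_type=RoundingType.STE)
class MySTEWeightQuantizer(BaseKerasGPTQTrainableQuantizer):
    ...  # quantization parameters and __call__ logic elided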
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -32,9 +32,8 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_mo
 from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
     get_weights_for_loss
 from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
-from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
 
 
 class PytorchGPTQTrainer(GPTQTrainer):
@@ -90,8 +89,8 @@
 
         self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
 
-    def _is_gptq_applicable(self,
-                            node: BaseNode) -> bool:
+    def _is_gptq_weights_trainable(self,
+                                   node: BaseNode) -> bool:
         """
         A function for deciding if a layer should be fine-tuned during GPTQ.
         Args:
@@ -105,7 +104,9 @@
                          f"without a kernel isn't supported.")
         return node.is_weights_quantization_enabled()
 
-    def gptq_wrapper(self, n: BaseNode, layer: Module) -> Union[qi.PytorchQuantizationWrapper, Module]:
+    def gptq_wrapper(self,
+                     n: BaseNode,
+                     layer: Module) -> Union[PytorchQuantizationWrapper, Module]:
         """
         A function which takes a computational graph node and a pytorch layer and perform the quantization wrapping.
 
@@ -116,14 +117,32 @@
         Returns: Wrapped layer if the layer should be wrap, otherwise returns the layer as is.
         """
 
-        if self._is_gptq_applicable(n):
+        if self._is_gptq_weights_trainable(n):
             weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config)
-            return qi.PytorchQuantizationWrapper(layer,
-                                                 weights_quantizers=weights_quantizers,
-                                                 activation_quantizers=activation_quantizers)
+            return PytorchQuantizationWrapper(layer,
+                                              weights_quantizers=weights_quantizers)
         else:
             return layer
 
+    def get_activation_quantizer_holder(self, n: BaseNode) -> Callable:
+        """
+        Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
+        If the layer is not supposed to be wrapped with an activation quantizer - return None.
+        Args:
+            n: Node to attach a PytorchActivationQuantizationHolder to its output.
+        Returns:
+            A PytorchActivationQuantizationHolder module for the node's activation quantization.
+        """
+        _, activation_quantizers = quantization_builder(n, self.gptq_config)
+        # Holder by definition uses a single quantizer for the activation quantization
+        # thus we make sure this is the only possible case (unless it's a node with no activation
+        # quantization, which in this case has an empty list).
+        if len(activation_quantizers) == 1:
+            return PytorchActivationQuantizationHolder(activation_quantizers[0])
+        Logger.error(
+            f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+            f'were found for node {n}')
+
     def build_gptq_model(self):
         """
         Build the GPTQ model with QuantizationWrappers
@@ -134,7 +153,8 @@
                                          append2output=self.compare_points,
                                          fw_info=self.fw_info,
                                          wrapper=self.gptq_wrapper,
-                                         return_float_outputs=True).build_model()
+                                         return_float_outputs=True,
+                                         get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
 
         return gptq_model, gptq_user_info
 
model_compression_toolkit/gptq/pytorch/graph_info.py
@@ -18,8 +18,9 @@ from typing import List
 from model_compression_toolkit.core.pytorch.constants import BIAS
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.logger import Logger
+from mct_quantizers import PytorchQuantizationWrapper
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
 def get_gptq_trainable_parameters(fxp_model: nn.Module,
@@ -46,6 +47,8 @@
                                                               fw_info=DEFAULT_PYTORCH_INFO)
 
             # collect trainable weights per quantizer
+            if kernel_attribute not in layer.weights_quantizers:
+                Logger.error(f'{kernel_attribute} was not found in weight quantizers of layer {layer.layer}')
             quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
             quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
             trainable_aux_weights.extend(quantizer_trainable_weights)
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -53,7 +53,8 @@ if FOUND_TORCH:
                                   optimizer: Optimizer = Adam([torch.Tensor([])], lr=LR_DEFAULT),
                                   optimizer_rest: Optimizer = Adam([torch.Tensor([])], lr=LR_REST_DEFAULT),
                                   loss: Callable = multiple_tensors_mse_loss,
-                                  log_function: Callable = None) -> GradientPTQConfigV2:
+                                  log_function: Callable = None,
+                                  use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
         """
         Create a GradientPTQConfigV2 instance for Pytorch models.
 
@@ -63,6 +64,7 @@
             optimizer_rest (Optimizer): Pytorch optimizer to use for fine-tuning of the bias variable.
             loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
             log_function (Callable): Function to log information about the gptq process.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
 
         returns:
             a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -84,7 +86,7 @@
         """
         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
         return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
-                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)
+                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, use_hessian_based_weights=use_hessian_based_weights)
 
 
     def pytorch_gradient_post_training_quantization_experimental(model: Module,
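
The facade's config factory gains a use_hessian_based_weights flag (default True) and forwards it into GradientPTQConfigV2. The hunk shows the function's parameters but not its name; the call below assumes it is get_pytorch_gptq_config from this facade module, and that n_epochs is its first parameter:

# Assumption: the patched factory is get_pytorch_gptq_config from the facade
# module shown in this diff; name and n_epochs parameter are inferred, not
# confirmed by the hunk itself.
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config

# Disable Hessian-based weighting of the per-layer loss terms.
gptq_config = get_pytorch_gptq_config(n_epochs=5, use_hessian_based_weights=False)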
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py
@@ -19,16 +19,16 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TORCH
 from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
 
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import \
     BaseTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
+from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
     BasePytorchTrainableQuantizer
 
 if FOUND_TORCH:
     from torch import Tensor
-    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+    from mct_quantizers import PytorchQuantizationWrapper
 
     class BasePytorchGPTQTrainableQuantizer(BasePytorchTrainableQuantizer):
         """