mct-nightly: 1.8.0.20052023.post401-py3-none-any.whl → 1.8.0.20230610.post356-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/METADATA +10 -7
  2. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/RECORD +68 -115
  3. model_compression_toolkit/__init__.py +23 -3
  4. model_compression_toolkit/core/common/framework_info.py +1 -1
  5. model_compression_toolkit/core/keras/back2framework/instance_builder.py +16 -9
  6. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +8 -34
  7. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +5 -1
  8. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +103 -28
  9. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +39 -44
  10. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +1 -1
  11. model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +20 -18
  12. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +3 -3
  13. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  14. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +36 -9
  15. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -4
  16. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +24 -32
  17. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +31 -8
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +5 -5
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +34 -8
  20. model_compression_toolkit/gptq/keras/gptq_training.py +15 -16
  21. model_compression_toolkit/gptq/keras/graph_info.py +2 -2
  22. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
  23. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +5 -7
  24. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  25. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +6 -6
  26. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -7
  27. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +6 -6
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +30 -10
  29. model_compression_toolkit/gptq/pytorch/graph_info.py +5 -2
  30. model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
  31. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +4 -4
  32. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +5 -7
  33. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  34. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +7 -7
  35. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -8
  36. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +7 -8
  37. model_compression_toolkit/qat/common/__init__.py +2 -1
  38. model_compression_toolkit/qat/common/qat_config.py +2 -2
  39. model_compression_toolkit/qat/keras/quantization_facade.py +18 -8
  40. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +1 -1
  41. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +11 -11
  42. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +11 -12
  43. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +12 -13
  44. model_compression_toolkit/qat/pytorch/quantization_facade.py +27 -16
  45. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  46. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +31 -4
  47. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +10 -9
  48. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +11 -10
  49. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
  50. model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +1 -25
  51. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py → trainable_infrastructure/__init__.py} +3 -10
  52. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/base_trainable_quantizer.py +3 -3
  53. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizer_config.py +1 -1
  54. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizers.py +3 -3
  55. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/base_keras_quantizer.py +4 -4
  56. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/config_serialization.py +2 -2
  57. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure}/keras/load_model.py +16 -23
  58. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/pytorch/base_pytorch_quantizer.py +3 -3
  59. model_compression_toolkit/quantizers_infrastructure/__init__.py +0 -23
  60. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +0 -87
  61. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +0 -46
  62. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +0 -31
  63. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +0 -53
  64. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +0 -49
  65. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +0 -147
  66. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +0 -345
  67. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +0 -85
  68. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +0 -27
  69. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  70. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -148
  71. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -65
  72. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -86
  73. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -111
  74. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +0 -56
  75. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  76. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -79
  77. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -179
  78. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -67
  79. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -87
  80. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -163
  81. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +0 -66
  82. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +0 -14
  83. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +0 -269
  84. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +0 -152
  85. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +0 -35
  86. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  87. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -96
  88. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -62
  89. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -83
  90. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -100
  91. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +0 -95
  92. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +0 -48
  93. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +0 -70
  94. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +0 -57
  95. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +0 -26
  96. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  97. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -77
  98. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -106
  99. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -66
  100. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -104
  101. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -109
  102. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +0 -14
  103. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +0 -14
  104. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +0 -14
  105. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +0 -14
  106. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/LICENSE.md +0 -0
  107. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/WHEEL +0 -0
  108. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/top_level.txt +0 -0
  109. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure/common}/__init__.py +0 -0
  110. /model_compression_toolkit/{quantizers_infrastructure → trainable_infrastructure/common}/constants.py +0 -0
  111. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/quant_utils.py +0 -0
  112. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/trainable_quantizer_config.py +0 -0
  113. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/common → trainable_infrastructure/keras}/__init__.py +0 -0
  114. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/quantizer_utils.py +0 -0
  115. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras → trainable_infrastructure/pytorch}/__init__.py +0 -0
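Taken together, the file list describes a single refactor: the in-tree `quantizers_infrastructure` package is deleted, its inferable half is now consumed from the external `mct-quantizers` package, and its trainable half moves to a new top-level `trainable_infrastructure` subpackage. A representative before/after import, assembled from the hunks below:

    # Before (1.8.0.20052023.post401): everything under quantizers_infrastructure.
    from model_compression_toolkit.quantizers_infrastructure import (
        PytorchQuantizationWrapper, QuantizationTarget)

    # After (1.8.0.20230610.post356): inferable pieces come from mct-quantizers,
    # trainable pieces from the new trainable_infrastructure subpackage.
    from mct_quantizers import PytorchQuantizationWrapper, QuantizationTarget
    from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig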
@@ -21,15 +21,13 @@ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_qu
     get_activation_inferable_quantizer_kwargs
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
-    get_inferable_quantizer_class
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
-    BasePyTorchInferableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+from mct_quantizers import QuantizationTarget
+from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_weights_config
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class


@@ -21,7 +21,7 @@ from torch import nn
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+from mct_quantizers import PytorchQuantizationWrapper


 class LinearTempDecay:
@@ -18,8 +18,8 @@ from typing import Dict
 import numpy as np

 from model_compression_toolkit.core.common import max_power_of_two
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
@@ -28,11 +28,11 @@ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as quti
 from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


 def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
@@ -68,7 +68,7 @@ def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
                            max_val=int_threshold - 1)


-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=RoundingType.SoftQuantizer)
 class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
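For orientation: `mark_quantizer` stamps a quantizer class with the target, method, and type keys that lookup helpers such as `get_trainable_quantizer_class` later match on. A toy re-implementation of the registration pattern (not mct-quantizers' actual code; all names here are illustrative):

    from enum import Enum

    class Target(Enum):  # stands in for QuantizationTarget
        Weights = 0
        Activation = 1

    _REGISTRY = {}  # (target, method, type) -> quantizer class

    def mark_quantizer(quantization_target, quantization_method, quantizer_type):
        def wrap(cls):
            # Register the class under every quantization method it supports.
            for method in quantization_method:
                _REGISTRY[(quantization_target, method, quantizer_type)] = cls
            return cls
        return wrap

    @mark_quantizer(quantization_target=Target.Weights,
                    quantization_method=['SYMMETRIC'],
                    quantizer_type='SoftQuantizer')
    class ToyQuantizer:
        pass

    assert _REGISTRY[(Target.Weights, 'SYMMETRIC', 'SoftQuantizer')] is ToyQuantizer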
@@ -110,7 +110,7 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: torch.Size,
                                 name: str,
-                                layer: qi.PytorchQuantizationWrapper):
+                                layer: PytorchQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary
@@ -17,9 +17,9 @@ import torch.nn as nn
 from typing import Dict
 import numpy as np

-from model_compression_toolkit import quantizers_infrastructure as qi
-from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
@@ -27,10 +27,9 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_
 from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
 from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
-    mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import \
     VariableGroup
 from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN

@@ -63,7 +62,7 @@ def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
                            max_val=2 ** num_bits - 1) + min_range


-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=RoundingType.SoftQuantizer)
 class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
@@ -100,7 +99,7 @@ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: torch.Size,
                                 name: str,
-                                layer: qi.PytorchQuantizationWrapper):
+                                layer: PytorchQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary
@@ -18,8 +18,8 @@ from typing import Dict
 import numpy as np
 from model_compression_toolkit.core.common.defaultdict import DefaultDict

-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
@@ -27,11 +27,10 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_
 from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
 from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
 from model_compression_toolkit.constants import THRESHOLD
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
-    mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
+from mct_quantizers import mark_quantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.quant_utils import \
     get_threshold_reshape_shape


@@ -75,7 +74,7 @@ def pertubation_symmetric_quantizer(input_tensor: torch.Tensor,
     return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)


-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=RoundingType.STE)
 class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
@@ -109,7 +108,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: torch.Size,
                                 name: str,
-                                layer: qi.PytorchQuantizationWrapper):
+                                layer: PytorchQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from model_compression_toolkit.quantizers_infrastructure.constants import THRESHOLD_TENSOR, WEIGHTS_QUANTIZATION_PARAMS
+from model_compression_toolkit.trainable_infrastructure.common.constants import THRESHOLD_TENSOR, \
+    WEIGHTS_QUANTIZATION_PARAMS
@@ -20,8 +20,8 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.logger import Logger


-def _is_qat_applicable(node: common.BaseNode,
-                       fw_info: FrameworkInfo) -> bool:
+def is_qat_applicable(node: common.BaseNode,
+                      fw_info: FrameworkInfo) -> bool:
     """
     A function for deciding if a layer should be fine-tuned during QAT
     Args:
@@ -22,7 +22,7 @@ from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
-from model_compression_toolkit.quantizers_infrastructure import ActivationQuantizationHolder
+from mct_quantizers import KerasActivationQuantizationHolder, KerasQuantizationWrapper
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.ptq.runner import ptq_runner
@@ -40,20 +40,18 @@ if FOUND_TF:
     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder

     from model_compression_toolkit import get_target_platform_capabilities
-    from model_compression_toolkit import quantizers_infrastructure as qi

     from model_compression_toolkit import get_target_platform_capabilities
     from model_compression_toolkit.core import common
     from model_compression_toolkit.core.common import BaseNode
     from model_compression_toolkit.constants import TENSORFLOW
     from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-    from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
+    from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
         get_activation_quantizer_holder
     from model_compression_toolkit.qat.common.qat_config import QATConfig
-    from model_compression_toolkit import quantizers_infrastructure as qi

     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

@@ -71,9 +69,12 @@ if FOUND_TF:
         Returns: Wrapped layer

         """
-        if _is_qat_applicable(n, DEFAULT_KERAS_INFO):
-            weights_quantizers, activation_quantizers = quantization_builder(n, qat_config, DEFAULT_KERAS_INFO)
-            return qi.KerasQuantizationWrapper(layer, weights_quantizers)
+        if is_qat_applicable(n, DEFAULT_KERAS_INFO):
+            weights_quantizers, _ = quantization_builder(n,
+                                                         qat_config,
+                                                         DEFAULT_KERAS_INFO)
+            if len(weights_quantizers) > 0:
+                return KerasQuantizationWrapper(layer, weights_quantizers)
         return layer


@@ -255,8 +256,17 @@

         """
         def _export(layer):
-            if isinstance(layer, (qi.KerasQuantizationWrapper, ActivationQuantizationHolder)):
+            if isinstance(layer, KerasQuantizationWrapper):
                 layer.convert_to_inferable_quantizers()
+            # In the KerasActivationQuantizationHolder case - converting the quantizers only
+            # is not enough. We need to create a new layer with inferable quantizers. The reason for that
+            # is that if we only convert the quantizers, the layer will have some weights (such as min, max,
+            # threshold) that do not match the configuration, thus loading such a model will fail.
+            # To overcome this, the convert_to_inferable_quantizers of KerasActivationQuantizationHolder
+            # creates a new layer from its new configuration after converting the trainable quantizer
+            # to an inferable quantizer.
+            elif isinstance(layer, KerasActivationQuantizationHolder):
+                layer = layer.convert_to_inferable_quantizers()
             return layer

         # clone each layer in the model and apply _export to layers with TrainableQuantizeWrappers
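The comment block above explains why holders must be rebuilt rather than converted in place. For context, a minimal sketch of how such an `_export` pass is typically applied with Keras model cloning; this mirrors, but is not, the toolkit's exact export code:

    import tensorflow as tf

    def export_trained_model(model: tf.keras.Model) -> tf.keras.Model:
        def _export(layer):
            # Duck-typed stand-in for the wrapper/holder handling shown in the
            # hunk above: holders return a freshly built layer, wrappers
            # convert their quantizers in place.
            if hasattr(layer, 'convert_to_inferable_quantizers'):
                converted = layer.convert_to_inferable_quantizers()
                return converted if converted is not None else layer
            return layer

        # clone_model invokes clone_function on every layer of the model.
        return tf.keras.models.clone_model(model, clone_function=_export)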
@@ -17,7 +17,7 @@ from typing import Union
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TF

-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig, BaseKerasTrainableQuantizer

 if FOUND_TF:
@@ -12,34 +12,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Tuple, Dict, List, Union, Callable
+from typing import Tuple, Dict, List, Callable

 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.qat.common.qat_config import QATConfig, _is_qat_applicable
+from model_compression_toolkit.qat.common.qat_config import QATConfig
 from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget, ActivationQuantizationHolder
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+from mct_quantizers import QuantizationTarget, KerasActivationQuantizationHolder
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config, \
     get_trainable_quantizer_quantization_candidates
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class


 def get_activation_quantizer_holder(n: common.BaseNode,
-                                    qat_config: QATConfig) -> Union[None, Callable]:
+                                    qat_config: QATConfig) -> Callable:
     """
-    Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
+    Retrieve a KerasActivationQuantizationHolder layer to use for activation quantization for a node.
     If the layer is not supposed to be wrapped with activation quantizers - return None.

     Args:
-        n: Node to get ActivationQuantizationHolder to attach in its output.
+        n: Node to get KerasActivationQuantizationHolder to attach in its output.
         qat_config: Configuration of QAT (such as training methods for example).

     Returns:
-        A ActivationQuantizationHolder layer for the node activation quantization.
+        A KerasActivationQuantizationHolder layer for the node activation quantization.
     """
     _, activation_quantizers = quantization_builder(n,
                                                     qat_config,
@@ -49,8 +49,8 @@ def get_activation_quantizer_holder(n: common.BaseNode,
     # thus we make sure this is the only possible case (unless it's a node with no activation
     # quantization, which in this case has an empty list).
     if len(activation_quantizers) == 1:
-        return ActivationQuantizationHolder(activation_quantizers[0])
-    Logger.error(f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers were found for node {n}')
+        return KerasActivationQuantizationHolder(activation_quantizers[0])
+    Logger.error(f'KerasActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers were found for node {n}')


 def quantization_builder(n: common.BaseNode,
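A short usage sketch of the holder returned above: it is an ordinary Keras layer, so the model builder can drop it onto a node's output. Illustrative only; the quantizer constructor arguments below are assumptions, not taken from this diff:

    import tensorflow as tf
    from mct_quantizers import KerasActivationQuantizationHolder
    from mct_quantizers.keras.quantizers import ActivationPOTInferableQuantizer

    # An inferable power-of-two activation quantizer wrapped in a holder layer.
    quantizer = ActivationPOTInferableQuantizer(num_bits=8, threshold=[1.0], signed=True)
    holder = KerasActivationQuantizationHolder(quantizer)

    inputs = tf.keras.Input(shape=(8,))
    outputs = holder(tf.keras.layers.Dense(8)(inputs))  # activations quantized here
    model = tf.keras.Model(inputs, outputs)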
@@ -19,25 +19,24 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.framework.tensor_shape import TensorShape
 from model_compression_toolkit.constants import SIGNED
-from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX

 from model_compression_toolkit.qat import TrainingMethod

 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationTarget, mark_quantizer, KerasQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
-from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
+from model_compression_toolkit import constants as C

 from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
-    WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, ActivationPOTInferableQuantizer, \
-    ActivationSymmetricInferableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, \
+    ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=TrainingMethod.STE)
 class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
@@ -84,7 +83,7 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: qi.KerasQuantizationWrapper):
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary
@@ -171,7 +170,7 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
                                            input_rank=len(self.threshold_shape))


-@mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
+@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 quantizer_type=TrainingMethod.STE)
 class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
@@ -206,7 +205,7 @@ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: qi.KerasQuantizationWrapper):
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary
@@ -16,25 +16,24 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.framework.tensor_shape import TensorShape
 from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
-from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX

 from model_compression_toolkit.qat import TrainingMethod
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+
+from mct_quantizers import mark_quantizer, QuantizationMethod, QuantizationTarget, KerasQuantizationWrapper
+from mct_quantizers.keras.quantizers import \
+    BaseKerasInferableQuantizer, WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer

 from model_compression_toolkit.qat.keras.quantizer.quant_utils import adjust_range_to_include_zero
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
-from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
+from model_compression_toolkit import constants as C

 from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
-    mark_quantizer
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
-    BaseKerasInferableQuantizer, WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


-@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=TrainingMethod.STE)
 class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
@@ -73,7 +72,7 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: qi.KerasQuantizationWrapper):
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary
@@ -148,7 +147,7 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
                                            input_rank=len(self.min_max_shape))


-@mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
+@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 quantizer_type=TrainingMethod.STE)
 class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
@@ -173,7 +172,7 @@ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
     def initialize_quantization(self,
                                 tensor_shape: TensorShape,
                                 name: str,
-                                layer: qi.KerasQuantizationWrapper):
+                                layer: KerasQuantizationWrapper):
         """
         Add quantizer parameters to the quantizer parameters dictionary
@@ -25,28 +25,32 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
+    TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.ptq.runner import ptq_runner

-
 if FOUND_TORCH:
     import torch.nn as nn
     from torch.nn import Module
+    from mct_quantizers import PytorchActivationQuantizationHolder
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
-    from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
+    from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
     from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
-    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
-    from model_compression_toolkit import quantizers_infrastructure as qi
+    from mct_quantizers import PytorchQuantizationWrapper
     from model_compression_toolkit import get_target_platform_capabilities
     from model_compression_toolkit.qat.common.qat_config import QATConfig
+    from model_compression_toolkit.qat.pytorch.quantizer.quantization_builder import get_activation_quantizer_holder
     from model_compression_toolkit.qat.pytorch.quantizer.quantization_builder import quantization_builder
+
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)


-    def qat_wrapper(n: common.BaseNode, module: nn.Module, qat_config: QATConfig):
+    def qat_wrapper(n: common.BaseNode,
+                    module: nn.Module,
+                    qat_config: QATConfig):
         """
         A function which takes a computational graph node and a pytorch module and perform the quantization wrapping
         Args:
@@ -56,11 +60,11 @@ if FOUND_TORCH:
         Returns: Wrapped layer

         """
-        if _is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
-            weights_quantizers, activation_quantizers = quantization_builder(n, qat_config, DEFAULT_PYTORCH_INFO)
-            return qi.PytorchQuantizationWrapper(module, weights_quantizers, activation_quantizers)
-        else:
-            return module
+        if is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
+            weights_quantizers, _ = quantization_builder(n, qat_config, DEFAULT_PYTORCH_INFO)
+            if len(weights_quantizers) > 0:
+                return PytorchQuantizationWrapper(module, weights_quantizers)
+        return module


     def pytorch_quantization_aware_training_init(in_model: Module,
@@ -135,11 +139,11 @@ if FOUND_TORCH:
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                    "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
-                    "or pass a valid mixed precision configuration.")
+                             "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
+                             "or pass a valid mixed precision configuration.")

         Logger.info("Using experimental mixed-precision quantization. "
-            "If you encounter an issue please file a bug.")
+                    "If you encounter an issue please file a bug.")

         tb_w = _init_tensorboard_writer(fw_info)
@@ -158,12 +162,18 @@ if FOUND_TORCH:

         _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

-        qat_model, user_info = PyTorchModelBuilder(graph=tg, fw_info=fw_info, wrapper=_qat_wrapper).build_model()
+        qat_model, user_info = PyTorchModelBuilder(graph=tg,
+                                                   fw_info=fw_info,
+                                                   wrapper=_qat_wrapper,
+                                                   get_activation_quantizer_holder_fn=partial(
+                                                       get_activation_quantizer_holder,
+                                                       qat_config=qat_config)).build_model()

         user_info.mixed_precision_cfg = bit_widths_config

         return qat_model, user_info

+
     def pytorch_quantization_aware_training_finalize(in_model: Module):
         """
         Convert a model fine-tuned by the user to a network with QuantizeWrappers containing
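For readers tracking the facade changes, hedged end-to-end usage of the two QAT entry points touched in this file; any argument beyond what appears in this diff is an assumption:

    import model_compression_toolkit as mct

    # Build a QAT-ready model: weights are wrapped in PytorchQuantizationWrapper,
    # activations get PytorchActivationQuantizationHolder layers via the new factory.
    qat_model, quantization_info = mct.pytorch_quantization_aware_training_init(
        in_model=float_model,  # a trained torch.nn.Module, supplied by the user
        representative_data_gen=representative_dataset)  # user-supplied generator

    # ... fine-tune qat_model ...

    # Swap trainable quantizers for inferable ones in wrappers and holders alike.
    quantized_model = mct.pytorch_quantization_aware_training_finalize(qat_model)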
@@ -207,7 +217,7 @@
         """
         exported_model = copy.deepcopy(in_model)
         for _, layer in exported_model.named_children():
-            if isinstance(layer, PytorchQuantizationWrapper):
+            if isinstance(layer, (PytorchQuantizationWrapper, PytorchActivationQuantizationHolder)):
                 layer.convert_to_inferable_quantizers()

         return exported_model
@@ -221,6 +231,7 @@ else:
                         'when using pytorch_quantization_aware_training_init. '
                         'Could not find the torch package.')  # pragma: no cover

+
     def pytorch_quantization_aware_training_finalize(*args, **kwargs):
         Logger.critical('Installing Pytorch is mandatory '
                         'when using pytorch_quantization_aware_training_finalize. '
@@ -17,9 +17,9 @@ from typing import Union
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TORCH

-from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
+from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
     BasePytorchTrainableQuantizer

 if FOUND_TORCH:
@@ -12,19 +12,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import List, Dict, Tuple
+from typing import List, Dict, Tuple, Callable
+
+from mct_quantizers import PytorchActivationQuantizationHolder, QuantizationTarget

 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.qat.common.qat_config import QATConfig
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_quantization_candidates, get_trainable_quantizer_weights_config, \
     get_trainable_quantizer_activation_config
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
-from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
-from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class

+def get_activation_quantizer_holder(n: common.BaseNode,
+                                    qat_config: QATConfig) -> Callable:
+    """
+    Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
+    If the layer is not supposed to be wrapped with activation quantizers - return None.
+
+    Args:
+        n: Node for which to retrieve an ActivationQuantizationHolder to attach to its output.
+        qat_config: QAT configuration (for example, training methods).
+
+    Returns:
+        A ActivationQuantizationHolder layer for the node's activation quantization.
+    """
+    _, activation_quantizers = quantization_builder(n,
+                                                    qat_config,
+                                                    DEFAULT_PYTORCH_INFO)
+
+    # Holder by definition uses a single quantizer for the activation quantization
+    # thus we make sure this is the only possible case (unless it's a node with no activation
+    # quantization, which in this case has an empty list).
+    if len(activation_quantizers) == 1:
+        return PytorchActivationQuantizationHolder(activation_quantizers[0])
+    Logger.error(f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers were found for node {n}')
+

 def quantization_builder(n: common.BaseNode,
                          qat_config: QATConfig,
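Mirroring the Keras sketch earlier, the holder added in this hunk is a plain torch.nn.Module wrapping a single quantizer; an illustrative forward pass follows (the quantizer's constructor arguments are assumptions, not taken from this diff):

    import torch
    from mct_quantizers import PytorchActivationQuantizationHolder
    from mct_quantizers.pytorch.quantizers import ActivationSymmetricInferableQuantizer

    # A symmetric 8-bit activation quantizer wrapped in the new holder module.
    quantizer = ActivationSymmetricInferableQuantizer(num_bits=8, threshold=[1.0], signed=True)
    holder = PytorchActivationQuantizationHolder(quantizer)
    y = holder(torch.randn(4, 8))  # forward pass returns the quantized activation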