mct-nightly 1.8.0.20052023.post401__py3-none-any.whl → 1.8.0.20230610.post356__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/METADATA +10 -7
  2. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/RECORD +68 -115
  3. model_compression_toolkit/__init__.py +23 -3
  4. model_compression_toolkit/core/common/framework_info.py +1 -1
  5. model_compression_toolkit/core/keras/back2framework/instance_builder.py +16 -9
  6. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +8 -34
  7. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +5 -1
  8. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +103 -28
  9. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +39 -44
  10. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +1 -1
  11. model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +20 -18
  12. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +3 -3
  13. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  14. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +36 -9
  15. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -4
  16. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +24 -32
  17. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +31 -8
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +5 -5
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +34 -8
  20. model_compression_toolkit/gptq/keras/gptq_training.py +15 -16
  21. model_compression_toolkit/gptq/keras/graph_info.py +2 -2
  22. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
  23. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +5 -7
  24. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  25. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +6 -6
  26. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -7
  27. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +6 -6
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +30 -10
  29. model_compression_toolkit/gptq/pytorch/graph_info.py +5 -2
  30. model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
  31. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +4 -4
  32. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +5 -7
  33. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  34. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +7 -7
  35. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -8
  36. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +7 -8
  37. model_compression_toolkit/qat/common/__init__.py +2 -1
  38. model_compression_toolkit/qat/common/qat_config.py +2 -2
  39. model_compression_toolkit/qat/keras/quantization_facade.py +18 -8
  40. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +1 -1
  41. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +11 -11
  42. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +11 -12
  43. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +12 -13
  44. model_compression_toolkit/qat/pytorch/quantization_facade.py +27 -16
  45. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  46. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +31 -4
  47. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +10 -9
  48. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +11 -10
  49. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
  50. model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +1 -25
  51. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py → trainable_infrastructure/__init__.py} +3 -10
  52. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/base_trainable_quantizer.py +3 -3
  53. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizer_config.py +1 -1
  54. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizers.py +3 -3
  55. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/base_keras_quantizer.py +4 -4
  56. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/config_serialization.py +2 -2
  57. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure}/keras/load_model.py +16 -23
  58. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/pytorch/base_pytorch_quantizer.py +3 -3
  59. model_compression_toolkit/quantizers_infrastructure/__init__.py +0 -23
  60. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +0 -87
  61. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +0 -46
  62. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +0 -31
  63. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +0 -53
  64. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +0 -49
  65. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +0 -147
  66. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +0 -345
  67. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +0 -85
  68. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +0 -27
  69. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  70. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -148
  71. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -65
  72. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -86
  73. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -111
  74. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +0 -56
  75. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  76. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -79
  77. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -179
  78. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -67
  79. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -87
  80. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -163
  81. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +0 -66
  82. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +0 -14
  83. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +0 -269
  84. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +0 -152
  85. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +0 -35
  86. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  87. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -96
  88. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -62
  89. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -83
  90. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -100
  91. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +0 -95
  92. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +0 -48
  93. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +0 -70
  94. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +0 -57
  95. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +0 -26
  96. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  97. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -77
  98. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -106
  99. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -66
  100. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -104
  101. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -109
  102. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +0 -14
  103. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +0 -14
  104. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +0 -14
  105. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +0 -14
  106. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/LICENSE.md +0 -0
  107. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/WHEEL +0 -0
  108. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/top_level.txt +0 -0
  109. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure/common}/__init__.py +0 -0
  110. model_compression_toolkit/{quantizers_infrastructure → trainable_infrastructure/common}/constants.py +0 -0
  111. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/quant_utils.py +0 -0
  112. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/trainable_quantizer_config.py +0 -0
  113. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/common → trainable_infrastructure/keras}/__init__.py +0 -0
  114. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/quantizer_utils.py +0 -0
  115. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras → trainable_infrastructure/pytorch}/__init__.py +0 -0
@@ -20,22 +20,23 @@ import torch.nn as nn
20
20
 
21
21
  from model_compression_toolkit.qat import TrainingMethod
22
22
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
+ from mct_quantizers import PytorchQuantizationWrapper
23
24
  from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
24
- from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
25
+ from model_compression_toolkit import constants as C
25
26
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
26
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
27
+ from mct_quantizers.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
27
28
 
28
29
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
29
30
  from model_compression_toolkit.qat.pytorch.quantizer.quantizer_utils import ste_round, ste_clip, symmetric_quantizer
30
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
31
+ from mct_quantizers.pytorch.quantizers import \
31
32
  WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, ActivationPOTInferableQuantizer, \
32
33
  ActivationSymmetricInferableQuantizer
33
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
34
+ from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
34
35
  TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
35
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
36
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
36
37
 
37
38
 
38
- @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
39
+ @mark_quantizer(quantization_target=QuantizationTarget.Weights,
39
40
  quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
40
41
  quantizer_type=TrainingMethod.STE)
41
42
  class STEWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
@@ -73,7 +74,7 @@ class STEWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
73
74
  def initialize_quantization(self,
74
75
  tensor_shape: torch.Size,
75
76
  name: str,
76
- layer: qi.PytorchQuantizationWrapper):
77
+ layer: PytorchQuantizationWrapper):
77
78
  """
78
79
  Add quantizer parameters to the quantizer parameters dictionary
79
80
 
@@ -129,7 +130,7 @@ class STEWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
129
130
 
130
131
 
131
132
 
132
- @mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
133
+ @mark_quantizer(quantization_target=QuantizationTarget.Activation,
133
134
  quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
134
135
  quantizer_type=TrainingMethod.STE)
135
136
  class STEActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
@@ -155,7 +156,7 @@ class STEActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
155
156
  def initialize_quantization(self,
156
157
  tensor_shape: torch.Size,
157
158
  name: str,
158
- layer: qi.PytorchQuantizationWrapper):
159
+ layer: PytorchQuantizationWrapper):
159
160
  """
160
161
  Add quantizer parameters to the quantizer parameters dictionary
161
162
 
@@ -18,24 +18,25 @@ import torch.nn as nn
18
18
  from torch import Tensor
19
19
 
20
20
  from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
21
- from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
21
+ from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
22
22
 
23
23
  from model_compression_toolkit.qat import TrainingMethod
24
24
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
25
- from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
25
+ from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
26
+ from model_compression_toolkit import constants as C
26
27
 
27
28
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
28
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
29
+ from mct_quantizers import mark_quantizer
29
30
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
30
31
  from model_compression_toolkit.qat.pytorch.quantizer.quantizer_utils import uniform_quantizer
31
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
32
+ from mct_quantizers.pytorch.quantizers import \
32
33
  WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
33
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
34
+ from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
34
35
  TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
35
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
36
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
36
37
 
37
38
 
38
- @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
39
+ @mark_quantizer(quantization_target=QuantizationTarget.Weights,
39
40
  quantization_method=[QuantizationMethod.UNIFORM],
40
41
  quantizer_type=TrainingMethod.STE)
41
42
  class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
@@ -69,7 +70,7 @@ class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
69
70
  def initialize_quantization(self,
70
71
  tensor_shape: torch.Size,
71
72
  name: str,
72
- layer: qi.PytorchQuantizationWrapper):
73
+ layer: PytorchQuantizationWrapper):
73
74
  """
74
75
  Add quantizer parameters to the quantizer parameters dictionary
75
76
 
@@ -117,7 +118,7 @@ class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
117
118
  channel_axis=self.quantization_config.weights_channels_axis)
118
119
 
119
120
 
120
- @mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
121
+ @mark_quantizer(quantization_target=QuantizationTarget.Activation,
121
122
  quantization_method=[QuantizationMethod.UNIFORM],
122
123
  quantizer_type=TrainingMethod.STE)
123
124
  class STEUniformActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
@@ -144,7 +145,7 @@ class STEUniformActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
144
145
  def initialize_quantization(self,
145
146
  tensor_shape: torch.Size,
146
147
  name: str,
147
- layer: qi.PytorchQuantizationWrapper):
148
+ layer: PytorchQuantizationWrapper):
148
149
  """
149
150
  Add quantizer parameters to the quantizer parameters dictionary
150
151
 
@@ -21,8 +21,9 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.targ
21
21
  get_default_quantization_config_options, TargetPlatformModel
22
22
 
23
23
  from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
24
- QuantizationConfigOptions, QuantizationMethod
24
+ QuantizationConfigOptions
25
25
  from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat
26
26
 
27
+ from mct_quantizers import QuantizationMethod
27
28
 
28
29
 
@@ -14,33 +14,9 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import copy
17
- from enum import Enum
18
17
  from typing import List
19
18
 
20
-
21
- class QuantizationMethod(Enum):
22
- """
23
- Method for quantization function selection:
24
-
25
- POWER_OF_TWO - Symmetric, uniform, threshold is power of two quantization.
26
-
27
- KMEANS - k-means quantization.
28
-
29
- LUT_POT_QUANTIZER - quantization using a lookup table and power of 2 threshold.
30
-
31
- SYMMETRIC - Symmetric, uniform, quantization.
32
-
33
- UNIFORM - uniform quantization,
34
-
35
- LUT_SYM_QUANTIZER - quantization using a lookup table and symmetric threshold.
36
-
37
- """
38
- POWER_OF_TWO = 0
39
- KMEANS = 1
40
- LUT_POT_QUANTIZER = 2
41
- SYMMETRIC = 3
42
- UNIFORM = 4
43
- LUT_SYM_QUANTIZER = 5
19
+ from mct_quantizers import QuantizationMethod
44
20
 
45
21
 
46
22
  class OpQuantizationConfig:
@@ -13,13 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- # Inferable keras quantizer signature parameters:
17
- NUM_BITS = 'num_bits'
18
- SIGNED = 'signed'
19
- THRESHOLD = 'threshold'
20
- PER_CHANNEL = 'per_channel'
21
- MIN_RANGE = 'min_range'
22
- MAX_RANGE = 'max_range'
23
- CHANNEL_AXIS = 'channel_axis'
24
- INPUT_RANK = 'input_rank'
25
- CLUSTER_CENTERS = 'cluster_centers'
16
+ from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
17
+ from model_compression_toolkit.trainable_infrastructure.keras.base_keras_quantizer import BaseKerasTrainableQuantizer
18
+ from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import BasePytorchTrainableQuantizer
@@ -20,11 +20,11 @@ from inspect import signature
20
20
  from model_compression_toolkit.core import common
21
21
  from model_compression_toolkit.logger import Logger
22
22
 
23
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer, \
23
+ from mct_quantizers.common.base_inferable_quantizer import BaseInferableQuantizer, \
24
24
  QuantizationTarget
25
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
25
+ from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
26
26
  TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
27
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import QUANTIZATION_METHOD, \
27
+ from mct_quantizers.common.constants import QUANTIZATION_METHOD, \
28
28
  QUANTIZATION_TARGET
29
29
 
30
30
 
@@ -15,7 +15,7 @@
15
15
  from typing import List
16
16
  from model_compression_toolkit.core.common import BaseNode
17
17
  from model_compression_toolkit.logger import Logger
18
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
18
+ from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
19
19
  TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig, TrainableQuantizerCandidateConfig
20
20
 
21
21
 
@@ -16,10 +16,10 @@ from typing import Union, Any
16
16
 
17
17
  from model_compression_toolkit.logger import Logger
18
18
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
- from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
20
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
19
+ from mct_quantizers import QuantizationTarget
20
+ from mct_quantizers.common.constants \
21
21
  import QUANTIZATION_TARGET, QUANTIZATION_METHOD, QUANTIZER_TYPE
22
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses \
22
+ from mct_quantizers.common.get_all_subclasses \
23
23
  import get_all_subclasses
24
24
 
25
25
 
@@ -16,14 +16,14 @@ from typing import Dict, Any, Union, List
16
16
 
17
17
  from model_compression_toolkit.logger import Logger
18
18
  from model_compression_toolkit.constants import FOUND_TF
19
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
- from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
19
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
+ from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
22
22
  TrainableQuantizerActivationConfig
23
23
 
24
24
  if FOUND_TF:
25
25
  QUANTIZATION_CONFIG = 'quantization_config'
26
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.config_serialization import config_serialization, \
26
+ from model_compression_toolkit.trainable_infrastructure.keras.config_serialization import config_serialization, \
27
27
  config_deserialization
28
28
  import tensorflow as tf
29
29
 
@@ -18,9 +18,9 @@ from typing import Any, Union
18
18
  from enum import Enum
19
19
 
20
20
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
21
+ from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
22
22
  TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
23
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common import constants as C
23
+ from mct_quantizers.common import constants as C
24
24
 
25
25
 
26
26
  def transform_enum(v: Any):
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,20 +12,24 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.logger import Logger
15
+ from typing import Any
16
+
17
+ import mct_quantizers
18
+ from mct_quantizers.common.get_all_subclasses import get_all_subclasses
19
+
16
20
  from model_compression_toolkit.constants import FOUND_TF
17
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses import get_all_subclasses
21
+ from model_compression_toolkit.logger import Logger
18
22
 
19
23
  if FOUND_TF:
20
24
  import tensorflow as tf
21
- from model_compression_toolkit import quantizers_infrastructure as qi
22
- from model_compression_toolkit.quantizers_infrastructure import BaseKerasTrainableQuantizer
23
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
25
+ from tensorflow.python.saved_model.load_options import LoadOptions
26
+ from model_compression_toolkit.trainable_infrastructure import BaseKerasTrainableQuantizer
24
27
  keras = tf.keras
25
28
 
26
- def keras_load_quantized_model(filepath, custom_objects=None, compile=True, options=None):
29
+ def keras_load_quantized_model(filepath: str, custom_objects: Any = None, compile: bool = True,
30
+ options: LoadOptions = None):
27
31
  """
28
- This function wraps the keras load model and MCT quantization custom class to it.
32
+ This function wraps the keras load model and adds trainable quantizers classes to its custom objects.
29
33
 
30
34
  Args:
31
35
  filepath: the model file path.
@@ -36,12 +40,6 @@ if FOUND_TF:
36
40
  Returns: A keras Model
37
41
 
38
42
  """
39
- qi_inferable_custom_objects = {subclass.__name__: subclass for subclass in
40
- get_all_subclasses(BaseKerasInferableQuantizer)}
41
- all_inferable_names = list(qi_inferable_custom_objects.keys())
42
- if len(set(all_inferable_names)) < len(all_inferable_names):
43
- Logger.error(f"Found multiple quantizers with the same name that inherit from BaseKerasInferableQuantizer"
44
- f"while trying to load a model.")
45
43
 
46
44
  qi_trainable_custom_objects = {subclass.__name__: subclass for subclass in
47
45
  get_all_subclasses(BaseKerasTrainableQuantizer)}
@@ -50,18 +48,13 @@ if FOUND_TF:
50
48
  Logger.error(f"Found multiple quantizers with the same name that inherit from BaseKerasTrainableQuantizer"
51
49
  f"while trying to load a model.")
52
50
 
53
- # Merge dictionaries into one dict
54
- qi_custom_objects = {**qi_inferable_custom_objects, **qi_trainable_custom_objects}
55
-
56
- # Add non-quantizers custom objects
57
- qi_custom_objects.update({qi.KerasQuantizationWrapper.__name__: qi.KerasQuantizationWrapper})
58
- qi_custom_objects.update({qi.ActivationQuantizationHolder.__name__: qi.ActivationQuantizationHolder})
51
+ qi_custom_objects = {**qi_trainable_custom_objects}
59
52
 
60
53
  if custom_objects is not None:
61
54
  qi_custom_objects.update(custom_objects)
62
- return tf.keras.models.load_model(filepath,
63
- custom_objects=qi_custom_objects, compile=compile,
64
- options=options)
55
+ return mct_quantizers.keras_load_quantized_model(filepath,
56
+ custom_objects=qi_custom_objects, compile=compile,
57
+ options=options)
65
58
  else:
66
59
  def keras_load_quantized_model(filepath, custom_objects=None, compile=True, options=None):
67
60
  """
@@ -16,9 +16,9 @@ from typing import Union, List
16
16
 
17
17
  from model_compression_toolkit.logger import Logger
18
18
  from model_compression_toolkit.constants import FOUND_TORCH
19
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
- from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
19
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
+ from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
22
22
  TrainableQuantizerActivationConfig
23
23
 
24
24
 
@@ -1,23 +0,0 @@
1
- # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget, BaseInferableQuantizer
17
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
18
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.base_keras_quantizer import BaseKerasTrainableQuantizer
19
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
20
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantize_wrapper import KerasQuantizationWrapper
21
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantize_wrapper import PytorchQuantizationWrapper
22
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder
23
-
@@ -1,87 +0,0 @@
1
- # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from enum import Enum
16
- from typing import Any, Dict, List
17
-
18
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
-
20
-
21
- class QuantizationTarget(Enum):
22
- Activation = "Activation"
23
- Weights = "Weights"
24
-
25
-
26
- def mark_quantizer(quantization_target: QuantizationTarget = None,
27
- quantization_method: List[QuantizationMethod] = None,
28
- quantizer_type: Any = None):
29
- """
30
- A function to be used as decoration for all inferable quantizers (which inherit from BaseInferableQuantizer).
31
- By decorating a class with this decoration, we can define required static properties of the quantizer.
32
-
33
- Args:
34
- quantization_target: QuantizationTarget value which indicates what is the target for quantization to
35
- use the quantizer for.
36
- quantization_method: A list of QuantizationMethod values to indicate all type of quantization methods that the
37
- quantizer supports.
38
- quantizer_type: The type of the quantizer (quantization technique).
39
- This can differ, depending on the purpose the quantizer is for.
40
-
41
- Returns: A function that decorates a class object.
42
-
43
- """
44
- def mark(quantizer_class_object: BaseInferableQuantizer):
45
- """
46
- Initializes the parameters for the decorator.
47
-
48
- Args:
49
- quantizer_class_object: The class to be decorated.
50
-
51
- Returns: A decorated class.
52
-
53
- """
54
- quantizer_class_object.quantization_target = quantization_target
55
- quantizer_class_object.quantization_method = quantization_method
56
- quantizer_class_object.quantizer_type = quantizer_type
57
-
58
- return quantizer_class_object
59
-
60
- return mark
61
-
62
-
63
- class BaseInferableQuantizer:
64
-
65
- def __init__(self):
66
- """
67
- This class is a base quantizer which defines an abstract
68
- function which any quantizer needs to implement.
69
- """
70
- pass
71
-
72
- def initialize_quantization(self,
73
- tensor_shape: Any,
74
- name: str,
75
- layer: Any) -> Dict[Any, Any]:
76
- """
77
- Return a dictionary of quantizer parameters and their names.
78
-
79
- Args:
80
- tensor_shape: tensor shape of the quantized tensor.
81
- name: Tensor name.
82
- layer: Layer to quantize.
83
-
84
- Returns:
85
- Dictionary of parameters names to the variables.
86
- """
87
- return {}
@@ -1,46 +0,0 @@
1
- # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- IS_WEIGHTS = "is_weights"
17
- IS_ACTIVATIONS = "is_activations"
18
-
19
- # In KerasQuantizationWrapper and PytorchQuantizationWrapper multiple quantizers are kept
20
- ACTIVATION_QUANTIZERS = "activation_quantizers"
21
- # In ActivationQuantizationHolder only one quantizer is used thus a new attribute name is needed
22
- ACTIVATION_HOLDER_QUANTIZER = "activation_holder_quantizer"
23
-
24
- WEIGHTS_QUANTIZERS = "weights_quantizer"
25
- WEIGHTS_QUANTIZATION_METHOD = 'weights_quantization_method'
26
- WEIGHTS_N_BITS = 'weights_n_bits'
27
- WEIGHTS_QUANTIZATION_PARAMS = 'weights_quantization_params'
28
- ENABLE_WEIGHTS_QUANTIZATION = 'enable_weights_quantization'
29
- WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
30
- WEIGHTS_PER_CHANNEL_THRESHOLD = 'weights_per_channel_threshold'
31
- MIN_THRESHOLD = 'min_threshold'
32
- ACTIVATION_QUANTIZATION_METHOD = 'activation_quantization_method'
33
- ACTIVATION_N_BITS = 'activation_n_bits'
34
- ACTIVATION_QUANTIZATION_PARAMS = 'activation_quantization_params'
35
- ENABLE_ACTIVATION_QUANTIZATION = 'enable_activation_quantization'
36
- LAYER = "layer"
37
- STEPS = "optimizer_step"
38
- TRAINING = "training"
39
-
40
-
41
- QUANTIZATION_TARGET = 'quantization_target'
42
- QUANTIZATION_METHOD = 'quantization_method'
43
- QUANTIZER_TYPE = 'quantizer_type'
44
-
45
- EPS = 1e-8
46
- MULTIPLIER_N_BITS = 8
@@ -1,31 +0,0 @@
1
- # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import Set
16
-
17
-
18
- def get_all_subclasses(cls: type) -> Set[type]:
19
- """
20
- This function returns a list of all subclasses of the given class,
21
- including all subclasses of those subclasses, and so on.
22
- Recursively get all subclasses of the subclass and add them to the list of all subclasses.
23
-
24
- Args:
25
- cls: A class object.
26
-
27
- Returns: All classes that inherit from cls.
28
-
29
- """
30
-
31
- return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in get_all_subclasses(c)])
@@ -1,53 +0,0 @@
1
- # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from model_compression_toolkit.logger import Logger
17
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
18
- from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
19
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import QUANTIZATION_TARGET, \
20
- QUANTIZATION_METHOD
21
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses import get_all_subclasses
22
-
23
-
24
- def get_inferable_quantizer_class(quant_target: QuantizationTarget,
25
- quant_method: QuantizationMethod,
26
- quantizer_base_class: type) -> type:
27
- """
28
- Searches for an inferable quantizer class that matches the requested QuantizationTarget and QuantizationMethod.
29
- Exactly one class should be found.
30
-
31
- Args:
32
- quant_target: QuantizationTarget value (Weights or Activation) which indicates what is the target for
33
- quantization to use the quantizer for.
34
- quant_method: A list of QuantizationMethod values to indicate all type of quantization methods that the
35
- quantizer supports.
36
- quantizer_base_class: A type of quantizer that the requested quantizer should inherit from.
37
-
38
- Returns: A class of a quantizer that inherits from BaseKerasInferableQuantizer.
39
-
40
- """
41
- qat_quantizer_classes = get_all_subclasses(quantizer_base_class)
42
- filtered_quantizers = list(filter(lambda q_class: getattr(q_class, QUANTIZATION_TARGET) == quant_target and
43
- getattr(q_class, QUANTIZATION_METHOD) is not None and
44
- quant_method in getattr(q_class, QUANTIZATION_METHOD),
45
- qat_quantizer_classes))
46
-
47
- if len(filtered_quantizers) != 1:
48
- Logger.error(f"Found {len(filtered_quantizers)} quantizer for target {quant_target.value} "
49
- f"that matches the requested quantization method {quant_method.name} "
50
- f"but there should be exactly one."
51
- f"The possible quantizers that were found are {filtered_quantizers}.")
52
-
53
- return filtered_quantizers[0]
@@ -1,49 +0,0 @@
1
- # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import numpy as np
16
- from typing import Tuple
17
-
18
-
19
- def adjust_range_to_include_zero(range_min: np.ndarray,
20
- range_max: np.ndarray,
21
- n_bits: int) -> Tuple[np.ndarray, np.ndarray]:
22
- """
23
- Adjusting the quantization range to include representation of 0.0 in the quantization grid.
24
- For per_channel quantization range_min\range_max should be tensors in the specific shape that allows
25
- quantization along the channel_axis.
26
-
27
- Args:
28
- range_min: min bound of the quantization range (before adjustment).
29
- range_max: max bound of the quantization range (before adjustment).
30
- n_bits: Number of bits to quantize the tensor.
31
-
32
- Returns: adjusted quantization range
33
- """
34
- scale = (range_max - range_min) / (2 ** n_bits - 1)
35
- min_range_adj = scale * np.round(range_min / scale)
36
- max_range_adj = range_max - range_min + min_range_adj
37
-
38
- min_positive = range_min > 0
39
- max_negative = range_max < 0
40
- mid_range = np.logical_and(np.logical_not(min_positive), np.logical_not(max_negative))
41
-
42
- min_range_adj = min_range_adj * mid_range + max_negative * range_min
43
- max_range_adj = max_range_adj * mid_range + min_positive * range_max
44
-
45
- # Make sure min_range_adj < 0 and max_range_adj > 0 to avoid small numeric error
46
- min_range_adj = np.minimum(min_range_adj, 0)
47
- max_range_adj = np.maximum(max_range_adj, 0)
48
-
49
- return min_range_adj, max_range_adj