mct-nightly 2.2.0.20240902.511__py3-none-any.whl → 2.2.0.20240904.449__py3-none-any.whl

This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (35)
  1. {mct_nightly-2.2.0.20240902.511.dist-info → mct_nightly-2.2.0.20240904.449.dist-info}/METADATA +6 -6
  2. {mct_nightly-2.2.0.20240902.511.dist-info → mct_nightly-2.2.0.20240904.449.dist-info}/RECORD +35 -26
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +1 -2
  5. model_compression_toolkit/qat/__init__.py +2 -2
  6. model_compression_toolkit/qat/common/qat_config.py +1 -19
  7. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  8. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  9. model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py +1 -1
  10. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  11. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +1 -1
  12. model_compression_toolkit/qat/pytorch/quantizer/{base_pytorch_qat_quantizer.py → base_pytorch_qat_weight_quantizer.py} +4 -13
  13. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +6 -116
  14. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +12 -122
  15. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +8 -7
  16. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +6 -84
  17. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +6 -85
  18. model_compression_toolkit/trainable_infrastructure/__init__.py +9 -3
  19. model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +9 -8
  20. model_compression_toolkit/trainable_infrastructure/common/training_method.py +31 -0
  21. model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  22. model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +2 -2
  23. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/__init__.py +19 -0
  24. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/base_activation_quantizer.py +22 -0
  25. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/lsq/__init__.py +14 -0
  26. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/lsq/symmetric_lsq.py +111 -0
  27. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/lsq/uniform_lsq.py +106 -0
  28. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/__init__.py +14 -0
  29. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/symmetric_ste.py +108 -0
  30. model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/uniform_ste.py +105 -0
  31. model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +7 -14
  32. model_compression_toolkit/{qat/pytorch/quantizer → trainable_infrastructure/pytorch}/quantizer_utils.py +79 -2
  33. {mct_nightly-2.2.0.20240902.511.dist-info → mct_nightly-2.2.0.20240904.449.dist-info}/LICENSE.md +0 -0
  34. {mct_nightly-2.2.0.20240902.511.dist-info → mct_nightly-2.2.0.20240904.449.dist-info}/WHEEL +0 -0
  35. {mct_nightly-2.2.0.20240902.511.dist-info → mct_nightly-2.2.0.20240904.449.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,106 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod, PytorchQuantizationWrapper
+ from mct_quantizers.pytorch.quantizers import ActivationUniformInferableQuantizer
+ from model_compression_toolkit import constants as C
+ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+ from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerActivationConfig, TrainingMethod
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+ from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
+ from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import \
+     BasePytorchActivationTrainableQuantizer
+ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils import uniform_lsq_quantizer
+
+
+ # moved (and renamed) from model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py
+ @mark_quantizer(quantization_target=QuantizationTarget.Activation,
+                 quantization_method=[QuantizationMethod.UNIFORM],
+                 identifier=TrainingMethod.LSQ)
+ class LSQUniformActivationTrainableQuantizer(BasePytorchActivationTrainableQuantizer):
+     """
+     Trainable constrained quantizer to quantize layer activations.
+     """
+
+     def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
+         """
+         Initialize a LSQUniformActivationTrainableQuantizer object with parameters to use
+         for uniform quantization.
+
+         Args:
+             quantization_config: trainable quantizer config class
+         """
+         super().__init__(quantization_config)
+         self.num_bits = self.quantization_config.activation_n_bits
+         self.min_int = 0
+         self.max_int = 2 ** self.num_bits - 1
+         self.min_range = np.array([quantization_config.activation_quantization_params[C.RANGE_MIN]])
+         self.max_range = np.array([quantization_config.activation_quantization_params[C.RANGE_MAX]])
+
+     def initialize_quantization(self,
+                                 tensor_shape: torch.Size,
+                                 name: str,
+                                 layer: PytorchQuantizationWrapper):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary
+
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+         layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(to_torch_tensor(self.min_range), requires_grad=True))
+         layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(to_torch_tensor(self.max_range), requires_grad=True))
+
+         # Save the quantizer parameters for later calculations
+         self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
+
+     def __call__(self,
+                  inputs: torch.Tensor,
+                  training: bool = True) -> torch.Tensor:
+         """
+         Quantize a tensor.
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether the graph is in training mode.
+
+         Returns:
+             The quantized tensor.
+         """
+         min_range = self.get_quantizer_variable(FQ_MIN)
+         max_range = self.get_quantizer_variable(FQ_MAX)
+         n_channels = inputs.shape[1]
+         scale_factor = 1.0 / np.sqrt(self.max_int * n_channels)
+         inputs_quantized = uniform_lsq_quantizer(inputs, min_range, max_range, self.num_bits, self.min_int, self.max_int, scale_factor)
+         return inputs_quantized
+
+     def convert2inferable(self) -> ActivationUniformInferableQuantizer:
+         """
+         Convert quantizer to inferable quantizer.
+
+         Returns:
+             A pytorch inferable quantizer object.
+         """
+         min_range = self.get_quantizer_variable(FQ_MIN).cpu().detach().numpy()
+         max_range = self.get_quantizer_variable(FQ_MAX).cpu().detach().numpy()
+         min_range, max_range = fix_range_to_include_zero(min_range, max_range, self.num_bits)
+         return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
+                                                    min_range=min_range.tolist(),
+                                                    max_range=max_range.tolist())
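
As a side note on the hunk above: the scale factor passed to uniform_lsq_quantizer in __call__ follows the usual LSQ gradient-scale heuristic, 1 / sqrt(max_int * n), with n taken here as the number of channels. A minimal sketch of that arithmetic, assuming a hypothetical unsigned 8-bit activation tensor with 64 channels (values are illustrative, not taken from the package):

import numpy as np

# Hypothetical unsigned 8-bit activation quantizer over a tensor with 64 channels.
num_bits = 8
min_int, max_int = 0, 2 ** num_bits - 1           # 0 .. 255
n_channels = 64

# Same expression as in __call__ above: 1 / sqrt(max_int * n_channels).
scale_factor = 1.0 / np.sqrt(max_int * n_channels)
print(f"{scale_factor:.5f}")                       # ~0.00783
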
@@ -0,0 +1,14 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,108 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import Union
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod, PytorchQuantizationWrapper
+ from mct_quantizers.pytorch.quantizers import ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
+ from model_compression_toolkit import constants as C
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+ from model_compression_toolkit.trainable_infrastructure import TrainingMethod, TrainableQuantizerActivationConfig
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+ from model_compression_toolkit.trainable_infrastructure.common.constants import THRESHOLD_TENSOR
+ from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import \
+     BasePytorchActivationTrainableQuantizer
+ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils import symmetric_quantizer
+
+
+ # moved (and renamed) from model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py
+ @mark_quantizer(quantization_target=QuantizationTarget.Activation,
+                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                 identifier=TrainingMethod.STE)
+ class STESymmetricActivationTrainableQuantizer(BasePytorchActivationTrainableQuantizer):
+     """
+     Trainable constrained quantizer to quantize a layer's activations.
+     """
+
+     def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
+         """
+         Initialize a STESymmetricActivationTrainableQuantizer object with parameters to use for symmetric or power-of-two quantization.
+
+         Args:
+             quantization_config: trainable quantizer config class
+         """
+         super().__init__(quantization_config)
+         self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
+         self.sign = quantization_config.activation_quantization_params['is_signed']
+         np_threshold_values = quantization_config.activation_quantization_params[C.THRESHOLD]
+         self.threshold_tensor = torch.Tensor([np_threshold_values])
+         self.num_bits = quantization_config.activation_n_bits
+
+     def initialize_quantization(self,
+                                 tensor_shape: torch.Size,
+                                 name: str,
+                                 layer: PytorchQuantizationWrapper):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary
+
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+         layer.register_parameter(name, nn.Parameter(to_torch_tensor(self.threshold_tensor),
+                                                     requires_grad=True))
+
+         # save the quantizer added parameters for later calculations
+         self.add_quantizer_variable(THRESHOLD_TENSOR, layer.get_parameter(name), VariableGroup.QPARAMS)
+
+     def __call__(self,
+                  inputs: torch.Tensor,
+                  training: bool = True) -> torch.Tensor:
+         """
+         Quantize a tensor.
+
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether the graph is in training mode.
+
+         Returns:
+             The quantized tensor.
+         """
+
+         _t = self.get_quantizer_variable(THRESHOLD_TENSOR)
+         q_tensor = symmetric_quantizer(inputs, _t, self.num_bits, sign=self.sign)
+         return q_tensor
+
+     def convert2inferable(self) -> Union[ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer]:
+         """
+         Convert quantizer to inferable quantizer.
+
+         Returns:
+             A pytorch inferable quantizer object.
+         """
+         np_threshold = self.get_quantizer_variable(THRESHOLD_TENSOR).cpu().detach().numpy()
+         if self.power_of_two:
+             pot_threshold = np.power(2.0, np.ceil(np.log2(np_threshold)))
+             return ActivationPOTInferableQuantizer(num_bits=self.num_bits,
+                                                    threshold=pot_threshold.tolist(),
+                                                    signed=self.sign)
+         else:
+             return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits,
+                                                          threshold=np_threshold.tolist(),
+                                                          signed=self.sign)
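
For reference, the power-of-two export in convert2inferable above simply rounds the learned threshold up to the nearest power of two before building the inferable quantizer. A minimal sketch of that step, using a made-up threshold value:

import numpy as np

# Hypothetical learned threshold; 2 ** ceil(log2(t)) rounds it up to a power of two.
trained_threshold = np.array([0.73])
pot_threshold = np.power(2.0, np.ceil(np.log2(trained_threshold)))
print(pot_threshold)   # -> [1.]
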
@@ -0,0 +1,105 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+ from torch import nn
+
+ from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod, PytorchQuantizationWrapper
+ from mct_quantizers.pytorch.quantizers import ActivationUniformInferableQuantizer
+ from model_compression_toolkit import constants as C
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+ from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerActivationConfig, TrainingMethod
+ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+ from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
+ from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import \
+     BasePytorchActivationTrainableQuantizer
+ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils import uniform_quantizer
+
+
+ # moved (and renamed) from model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py
+ @mark_quantizer(quantization_target=QuantizationTarget.Activation,
+                 quantization_method=[QuantizationMethod.UNIFORM],
+                 identifier=TrainingMethod.STE)
+ class STEUniformActivationTrainableQuantizer(BasePytorchActivationTrainableQuantizer):
+     """
+     Trainable constrained quantizer to quantize a layer's activations.
+     """
+
+     def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
+         """
+         Initialize a STEUniformActivationTrainableQuantizer object with parameters to use for uniform quantization.
+
+         Args:
+             quantization_config: trainable quantizer config class
+         """
+         super().__init__(quantization_config)
+
+         np_min_range = quantization_config.activation_quantization_params[C.RANGE_MIN]
+         np_max_range = quantization_config.activation_quantization_params[C.RANGE_MAX]
+         self.min_range_tensor = torch.Tensor([np_min_range])
+         self.max_range_tensor = torch.Tensor([np_max_range])
+         self.num_bits = quantization_config.activation_n_bits
+
+     def initialize_quantization(self,
+                                 tensor_shape: torch.Size,
+                                 name: str,
+                                 layer: PytorchQuantizationWrapper):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary
+
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+         layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(to_torch_tensor(self.min_range_tensor),
+                                                                requires_grad=True))
+         layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(to_torch_tensor(self.max_range_tensor),
+                                                                requires_grad=True))
+
+         # Save the quantizer parameters for later calculations
+         self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
+
+     def __call__(self,
+                  inputs: torch.Tensor,
+                  training: bool = True) -> torch.Tensor:
+         """
+         Quantize a tensor.
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether the graph is in training mode.
+
+         Returns:
+             The quantized tensor.
+         """
+
+         _min = self.get_quantizer_variable(FQ_MIN)
+         _max = self.get_quantizer_variable(FQ_MAX)
+         q_tensor = uniform_quantizer(inputs, _min, _max, self.num_bits)
+         return q_tensor
+
+     def convert2inferable(self) -> ActivationUniformInferableQuantizer:
+         """
+         Convert quantizer to inferable quantizer.
+
+         Returns:
+             A pytorch inferable quantizer object.
+         """
+         _min = self.get_quantizer_variable(FQ_MIN).cpu().detach().numpy()
+         _max = self.get_quantizer_variable(FQ_MAX).cpu().detach().numpy()
+
+         return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
+                                                    min_range=_min.tolist(),
+                                                    max_range=_max.tolist())
@@ -12,31 +12,24 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Union, List
+ from typing import List
+
+ from abc import ABC

  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.verify_packages import FOUND_TORCH
  from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
  from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
- from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
-     TrainableQuantizerActivationConfig


  if FOUND_TORCH:

      import torch

-     class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
-         def __init__(self,
-                      quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
-             """
-             This class is a base Pytorch quantizer which validates the provided quantization config and defines an
-             abstract function which any quantizer needs to implement.
-
-             Args:
-                 quantization_config: quantizer config class contains all the information about the quantizer configuration.
-             """
-             super().__init__(quantization_config)
+     class BasePytorchTrainableQuantizer(BaseTrainableQuantizer, ABC):
+         """
+         Base class for PyTorch trainable quantizers
+         """

          def get_trainable_variables(self, group: VariableGroup) -> List[torch.Tensor]:
              """
@@ -12,15 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Tuple
+ from typing import Tuple, Union
  import torch
+ from torch import nn


  def ste_round(x: torch.Tensor) -> torch.Tensor:
      """
      Calculate the rounded values of a tensor
+
      Args:
          x: input variable
+
      Returns:
          rounded value
      """
@@ -30,10 +33,12 @@ def ste_round(x: torch.Tensor) -> torch.Tensor:
  def ste_clip(x: torch.Tensor, min_val=-1.0, max_val=1.0) -> torch.Tensor:
      """
      Clip a variable between fixed values such that min_val<=output<=max_val
+
      Args:
          x: input variable
          min_val: minimum value for clipping
          max_val: maximum value for clipping
+
      Returns:
          clipped variable
      """
@@ -43,9 +48,11 @@ def ste_clip(x: torch.Tensor, min_val=-1.0, max_val=1.0) -> torch.Tensor:
  def grad_scale(x: torch.Tensor, scale=1.0) -> torch.Tensor:
      """
      Gradient scale
+
      Args:
          x: input variable
          scale: scale factor
+
      Returns:
          x in forward and x*scale in backward (for scaling the gradients).
      """
@@ -60,11 +67,14 @@ def adjust_range_to_include_zero(range_min: torch.Tensor,
      Adjusting the quantization range to include representation of 0.0 in the quantization grid.
      If quantization per-channel, then range_min and range_max should be tensors in the specific shape that allows
      quantization along the channel_axis.
+
      Args:
          range_min: min bound of the quantization range (before adjustment).
          range_max: max bound of the quantization range (before adjustment).
          n_bits: Number of bits to quantize the tensor.
-     Returns: adjusted quantization range
+
+     Returns:
+         adjusted quantization range
      """
      min_positive = range_min > 0
      max_negative = range_max < 0
@@ -89,11 +99,13 @@ def symmetric_quantizer(tensor_data: torch.Tensor,
      """
      Quantize a tensor according to the number of bits and threshold.
      Symmetric quantization.
+
      Args:
          tensor_data: Tensor values to quantize.
          threshold: threshold for quantization.
          n_bits: Number of bits to quantize the tensor.
          sign: sign of tensor_data
+
      Returns:
          Quantized data.
      """
@@ -124,11 +136,13 @@ def uniform_quantizer(tensor_data: torch.Tensor,
      """
      Quantize a tensor according to given range (min, max) and number of bits.
      Uniform quantization.
+
      Args:
          tensor_data: Tensor values to quantize.
          range_min: minimum bound of the range for quantization (or array of min values per channel).
          range_max: maximum bound of the range for quantization (or array of max values per channel).
          n_bits: Number of bits to quantize the tensor.
+
      Returns:
          Quantized data.
      """
@@ -147,3 +161,66 @@ def uniform_quantizer(tensor_data: torch.Tensor,
      # Quantize the data between min/max of quantization range.
      q = delta_tensor * clipped_tensor + a
      return q
+
+
+ # moved from model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py
+ def symmetric_lsq_quantizer(x: nn.Parameter,
+                             thresholds: nn.Parameter,
+                             num_bits: int,
+                             sign: bool,
+                             min_int: int,
+                             max_int: int,
+                             scale_factor: float) -> Union[nn.Parameter, torch.Tensor]:
+     """
+     Symmetric quantizer according to LSQ algorithm: https://arxiv.org/pdf/1902.08153.pdf
+
+     Args:
+         x: input to quantize
+         thresholds: thresholds of quantization levels
+         num_bits: number of bits for quantization
+         sign: whether x is signed or not
+         min_int: min clipping integer value
+         max_int: max clipping integer value
+         scale_factor: grad scale of LSQ algorithm
+
+     Returns:
+         A quantized tensor
+     """
+     delta = thresholds / (2 ** (num_bits - int(sign)))
+     delta_scaled = grad_scale(delta, scale_factor)
+     rounded = ste_round(x / delta_scaled)
+     clipped = torch.clip(rounded, min=min_int, max=max_int)
+     quantized = delta_scaled * clipped
+     return quantized
+
+
+ # moved from model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py
+ def uniform_lsq_quantizer(x: nn.Parameter,
+                           min_range: nn.Parameter,
+                           max_range: nn.Parameter,
+                           num_bits: int,
+                           min_int: int,
+                           max_int: int,
+                           scale_factor: float) -> Union[nn.Parameter, torch.Tensor]:
+     """
+     Uniform quantizer according to LSQ algorithm: https://arxiv.org/pdf/1902.08153.pdf
+
+     Args:
+         x: input to quantize
+         min_range: min range of quantization values
+         max_range: max range of quantization values
+         num_bits: number of bits for quantization
+         min_int: min clipping integer value
+         max_int: max clipping integer value
+         scale_factor: grad scale of LSQ algorithm
+
+     Returns:
+         A quantized tensor
+     """
+     a, b = adjust_range_to_include_zero(min_range, max_range, num_bits)
+     delta = (b - a) / (2 ** num_bits - 1)
+     delta_scaled = grad_scale(delta, scale_factor)
+     rounded = ste_round((x - a) / delta_scaled)
+     clipped = torch.clip(rounded, min=min_int, max=max_int)
+     quantized = delta_scaled * clipped + a
+     return quantized
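
The two LSQ helpers added above follow Esser et al. (https://arxiv.org/pdf/1902.08153.pdf): quantize with a straight-through round and scale the gradient that reaches the learned step size. A self-contained sketch of the same symmetric step in plain PyTorch, independent of MCT's ste_round/grad_scale helpers and using hypothetical signed 8-bit settings, is:

import torch

def ste_round(x: torch.Tensor) -> torch.Tensor:
    # Straight-through estimator: round in the forward pass, identity gradient in the backward pass.
    return (torch.round(x) - x).detach() + x

def grad_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Forward returns x unchanged; backward multiplies the incoming gradient by `scale`.
    return (x - x * scale).detach() + x * scale

# Hypothetical signed 8-bit symmetric LSQ step (mirrors symmetric_lsq_quantizer above).
num_bits, sign = 8, True
min_int, max_int = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
x = torch.randn(4, 16, requires_grad=True)
threshold = torch.tensor([1.0], requires_grad=True)
scale_factor = 1.0 / (max_int * x.numel()) ** 0.5   # illustrative LSQ gradient scale

delta = threshold / (2 ** (num_bits - int(sign)))
delta_scaled = grad_scale(delta, scale_factor)
x_q = delta_scaled * torch.clip(ste_round(x / delta_scaled), min_int, max_int)
x_q.sum().backward()   # gradients reach both the input and the learned threshold

The uniform variant above applies the same pattern after shifting the tensor by the learned minimum of the range.
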