mct-nightly 1.8.0.27022023.post430-py3-none-any.whl → 1.8.0.27032023.post403-py3-none-any.whl

This diff compares the contents of two package versions as published to one of the supported public registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in those registries.
Files changed (66)
  1. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/METADATA +7 -7
  2. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/RECORD +65 -59
  3. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +9 -15
  5. model_compression_toolkit/core/common/logger.py +10 -2
  6. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  7. model_compression_toolkit/core/keras/quantization_facade.py +1 -1
  8. model_compression_toolkit/core/pytorch/constants.py +4 -0
  9. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +4 -10
  10. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  11. model_compression_toolkit/exporter/__init__.py +5 -0
  12. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  13. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
  14. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  15. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -39
  16. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +39 -24
  17. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +50 -42
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -36
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +24 -5
  20. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  21. model_compression_toolkit/gptq/__init__.py +6 -0
  22. model_compression_toolkit/gptq/common/gptq_config.py +60 -106
  23. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  24. model_compression_toolkit/gptq/common/gptq_training.py +28 -38
  25. model_compression_toolkit/gptq/keras/gptq_training.py +10 -28
  26. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  27. model_compression_toolkit/gptq/keras/quantization_facade.py +6 -12
  28. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +0 -1
  29. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  30. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  31. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  32. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +22 -128
  33. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +11 -41
  34. model_compression_toolkit/gptq/pytorch/gptq_training.py +12 -4
  35. model_compression_toolkit/gptq/pytorch/graph_info.py +9 -6
  36. model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -22
  37. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +3 -1
  38. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +0 -20
  39. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -1
  40. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  41. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  42. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py +14 -0
  43. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  44. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +236 -0
  45. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  46. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +9 -31
  47. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +30 -37
  48. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +27 -36
  49. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +21 -21
  50. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +25 -26
  51. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  52. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
  53. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +12 -0
  54. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  55. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  56. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +12 -0
  57. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  58. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  59. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +53 -2
  60. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +2 -1
  61. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +22 -4
  62. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +24 -3
  63. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  64. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/LICENSE.md +0 -0
  65. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/top_level.txt +0 -0
  66. /model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +0 -0
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py (new file)
@@ -0,0 +1,196 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ import numpy as np
+
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_config import RoundingType
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
+     BasePytorchGPTQTrainableQuantizer
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
+ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
+ from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
+ from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
+     mark_quantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
+     VariableGroup
+ from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
+ from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
+
+
+ def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
+                                     auxvar_tensor: torch.Tensor,
+                                     min_range: torch.Tensor,
+                                     max_range: torch.Tensor,
+                                     num_bits: int) -> torch.Tensor:
+     """
+     Quantize a tensor uniformly for GPTQ quantizers.
+
+     Args:
+         input_tensor: Tensor to quantize. Values of this tensor are not changed during GPTQ.
+         auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to GPTQ training.
+         min_range: Tensor with min values to compute the delta grid.
+         max_range: Tensor with max values to compute the delta grid.
+         num_bits: Number of bits to use.
+
+     Returns:
+         A quantized tensor.
+     """
+     # Adjust the quantization range so the quantization grid includes zero.
+     min_range, max_range = fix_range_to_include_zero(min_range, max_range, num_bits)
+     delta = qutils.calculate_delta_uniform(max_range, min_range, num_bits)
+     with torch.no_grad():
+         input_tensor_int = torch.floor(input_tensor / delta)
+     tensor_q = input_tensor_int + auxvar_tensor
+     return delta * qutils.ste_clip(tensor_q,
+                                    min_val=0,
+                                    max_val=2 ** num_bits - 1)
+
+
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                 quantization_method=[QuantizationMethod.UNIFORM],
+                 quantizer_type=RoundingType.SoftQuantizer)
+ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
+     """
+     Trainable uniform quantizer that optimizes the rounding of the quantized values using a soft quantization method.
+     """
+
+     def __init__(self,
+                  quantization_config: TrainableQuantizerWeightsConfig,
+                  quantization_parameter_learning: bool = False):
+         """
+         Construct a PyTorch trainable quantizer that applies a soft-rounding fake weight quantizer for uniform quantization.
+
+         Args:
+             quantization_config: Trainable weights quantizer config.
+             quantization_parameter_learning (bool): Whether to learn the min/max ranges or not.
+         """
+
+         super().__init__(quantization_config)
+         self.num_bits = quantization_config.weights_n_bits
+         self.per_channel = quantization_config.weights_per_channel_threshold
+
+         self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
+         self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
+
+         self.quantization_axis = quantization_config.weights_channels_axis
+         self.quantization_parameter_learning = quantization_parameter_learning
+
+         # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
+         # See: https://arxiv.org/pdf/2004.10568.pdf
+         self.gamma = SOFT_ROUNDING_GAMMA
+         self.zeta = SOFT_ROUNDING_ZETA
+
+     def initialize_quantization(self,
+                                 tensor_shape: torch.Size,
+                                 name: str,
+                                 layer: qi.PytorchQuantizationWrapper):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary.
+
+         Args:
+             tensor_shape: Tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+
+         # Add min and max variables to the layer.
+         if self.per_channel:
+             min_values = to_torch_tensor(self.min_values)
+             max_values = to_torch_tensor(self.max_values)
+         else:
+             min_values = torch.tensor(self.min_values)
+             max_values = torch.tensor(self.max_values)
+
+         layer.register_parameter(name + "_" + FQ_MIN, nn.Parameter(min_values, requires_grad=self.quantization_parameter_learning))
+         layer.register_parameter(name + "_" + FQ_MAX, nn.Parameter(max_values, requires_grad=self.quantization_parameter_learning))
+
+         w = layer.layer.weight
+         delta = qutils.calculate_delta_uniform(max_values, min_values, self.num_bits)
+         w_clipped_normed = torch.clip(w / delta, 0, 2 ** self.num_bits - 1)
+         rest = w_clipped_normed - torch.floor(w_clipped_normed)  # rest of rounding [0, 1)
+         alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
+         layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
+
+         # Save the quantizer parameters
+         self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name + "_" + FQ_MIN), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name + "_" + FQ_MAX), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
+
+     def get_soft_targets(self) -> torch.Tensor:
+         """
+         Computes the rectified sigmoid function for the quantization target parameters.
+
+         Returns:
+             A tensor with the soft rounding target values.
+         """
+         scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
+         return torch.clip(scaled_sigmoid, min=0, max=1)
+
+     def get_quant_config(self) -> Dict[str, np.ndarray]:
+         """
+         Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
+
+         Returns:
+             A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+             Keys must match NodeQuantizationConfig attributes.
+         """
+         min_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MIN))
+         max_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MAX))
+         return {RANGE_MIN: min_values,
+                 RANGE_MAX: max_values}
+
+     def __call__(self,
+                  inputs: nn.Parameter,
+                  training: bool) -> torch.Tensor:
+         """
+         Quantize a tensor.
+
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether in training mode or not.
+
+         Returns:
+             The quantized tensor.
+         """
+         auxvar = self.get_quantizer_variable(AUXVAR)
+         min_range = self.get_quantizer_variable(FQ_MIN)
+         max_range = self.get_quantizer_variable(FQ_MAX)
+
+         #####################################################
+         # Soft Rounding
+         #####################################################
+         aux_var = self.get_soft_targets()
+         if not training:
+             aux_var = (aux_var >= 0.5).to(auxvar.dtype)
+
+         #####################################################
+         # Quantized Input
+         #####################################################
+         q_tensor = soft_rounding_unifrom_quantizer(input_tensor=inputs,
+                                                    auxvar_tensor=aux_var,
+                                                    min_range=min_range,
+                                                    max_range=max_range,
+                                                    num_bits=self.num_bits)
+
+         return q_tensor
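
The new file above implements AdaRound-style soft rounding (the paper linked in the source, arXiv:2004.10568): each weight's rounding decision is relaxed into a trainable variable alpha that passes through a "rectified sigmoid" and is hardened to a 0/1 rounding offset at inference. Below is a minimal standalone sketch of just that math, assuming the paper's stretch values gamma = -0.1 and zeta = 1.1; the real constants SOFT_ROUNDING_GAMMA/SOFT_ROUNDING_ZETA live in gptq_constants and are not shown in this diff.

import torch

GAMMA, ZETA = -0.1, 1.1  # assumed AdaRound stretch values, not taken from this diff

def init_alpha(w: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    # Invert the rectified sigmoid so the soft target starts exactly at the
    # fractional part of w/delta, mirroring initialize_quantization() above.
    rest = (w / delta) - torch.floor(w / delta)
    return -torch.log((ZETA - GAMMA) / (rest - GAMMA) - 1)

def soft_target(alpha: torch.Tensor) -> torch.Tensor:
    # Rectified sigmoid h(alpha) = clip(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1),
    # as computed by get_soft_targets() above.
    return torch.clip(torch.sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, min=0, max=1)

w = torch.randn(8) * 0.05
delta = torch.tensor(0.01)
alpha = init_alpha(w, delta)
w_q_train = delta * (torch.floor(w / delta) + soft_target(alpha))                   # training path
w_q_eval = delta * (torch.floor(w / delta) + (soft_target(alpha) >= 0.5).float())   # inference path

At initialization soft_target(alpha) equals the fractional remainder, so the "quantized" weight reproduces the original float weight within the clipping range; GPTQ training then pushes each alpha (the AUXVAR parameter) toward a hard 0 or 1, with the regularization added in soft_quantizer_reg.py (file 43 above) encouraging that binary decision.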
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py
@@ -14,7 +14,7 @@
  # ==============================================================================
  import torch
  import torch.nn as nn
- from typing import List, Dict
+ from typing import Dict
  import numpy as np
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
 
@@ -30,7 +30,7 @@ from model_compression_toolkit.core.common.constants import THRESHOLD
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
      mark_quantizer
-
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
      get_threshold_reshape_shape
 
@@ -104,23 +104,19 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
          self.quantization_axis = quantization_config.weights_channels_axis
          self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
          self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
-         self.quantizer_parameters = {}
 
 
      def initialize_quantization(self,
                                  tensor_shape: torch.Size,
                                  name: str,
-                                 layer: qi.PytorchQuantizationWrapper) -> Dict[str, nn.Parameter]:
+                                 layer: qi.PytorchQuantizationWrapper):
          """
-         Return a dictionary of quantizer parameters and their names.
+         Add quantizer parameters to the quantizer parameters dictionary
 
          Args:
              tensor_shape: tensor shape of the quantized tensor.
              name: Tensor name.
              layer: Layer to quantize.
-
-         Returns:
-             Dictionary of parameters names to the variables.
          """
 
          layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
@@ -131,27 +127,9 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
                                                requires_grad=True))
 
          # save the quantizer added parameters for later calculations
-         self.quantizer_parameters = {PTQ_THRESHOLD: layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"),
-                                      AUXVAR: layer.get_parameter(f"{name}_{AUXVAR}")}
-
-         return self.quantizer_parameters
-
-
-     def get_aux_variable(self) -> List[torch.Tensor]:
-         """
-         This function return a list with the quantizer's quantization auxiliary variables.
-
-         Returns: A list with the quantization auxiliary variables.
-         """
-         return [self.quantizer_parameters.get(AUXVAR)]
+         self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
 
-     def get_quantization_variable(self) -> List[torch.Tensor]:
-         """
-         This function return a list with the quantizer's quantization parameters variables.
-
-         Returns: A list with the quantization parameters.
-         """
-         return [self.quantizer_parameters.get(PTQ_THRESHOLD)]
 
@@ -162,7 +140,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
              Keys must match NodeQuantizationConfig attributes
 
          """
-         old_threshold = self.quantizer_parameters[PTQ_THRESHOLD]
+         old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
          return {THRESHOLD: torch_tensor_to_numpy(old_threshold).reshape(self.threshold_shape)}
 
      def __call__(self,
@@ -178,8 +156,8 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
          Returns:
              quantized tensor
          """
-         auxvar = self.quantizer_parameters[AUXVAR]
-         ptq_threshold_tensor = self.quantizer_parameters[PTQ_THRESHOLD]
+         auxvar = self.get_quantizer_variable(AUXVAR)
+         ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
 
          if self.per_channel:
              reshape_shape = get_threshold_reshape_shape(inputs.shape,
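
A pattern repeated across these diffs: the per-quantizer `self.quantizer_parameters` dict (and the `get_aux_variable`/`get_quantization_variable` helpers) is replaced by `add_quantizer_variable`/`get_quantizer_variable` on the trainable-quantizer base class, with every variable tagged by a `VariableGroup`. The base class itself grows by +53 lines in base_trainable_quantizer.py (file 59 above, not shown in these hunks). A minimal sketch of what such a registry could look like; the class and the `trainable_variables_by_group` helper are hypothetical stand-ins, not MCT's actual implementation:

from enum import Enum
from typing import Any, Dict, List

class VariableGroup(Enum):
    # The two groups used throughout the diff.
    WEIGHTS = 0   # trainable rounding/weight variables (e.g., AUXVAR)
    QPARAMS = 1   # quantization parameters (thresholds, min/max ranges)

class TrainableQuantizerSketch:
    def __init__(self):
        self._variables: Dict[str, Dict[str, Any]] = {}

    def add_quantizer_variable(self, name: str, variable: Any, group: VariableGroup) -> None:
        # Store the variable together with its group instead of a bare dict entry.
        self._variables[name] = {'var': variable, 'group': group}

    def get_quantizer_variable(self, name: str) -> Any:
        return self._variables[name]['var']

    def trainable_variables_by_group(self, group: VariableGroup) -> List[Any]:
        # Subsumes the removed get_aux_variable()/get_quantization_variable()
        # helpers with a single group-based query.
        return [v['var'] for v in self._variables.values() if v['group'] is group]

The benefit is that a GPTQ training loop can collect, say, all WEIGHTS variables across every quantizer in the model without knowing which quantizer class created them.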
model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py
@@ -13,7 +13,7 @@
  # limitations under the License.
  # ==============================================================================
 
- from typing import Dict, Union
+ from typing import Union
 
  import numpy as np
  import tensorflow as tf
@@ -32,6 +32,7 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
      WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, ActivationPOTInferableQuantizer, \
      ActivationSymmetricInferableQuantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
  @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
@@ -76,22 +77,19 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
          max_int = (2 ** (self.num_bits - int(C.WEIGHTS_SIGNED))) - 1
          self.min = delta * min_int
          self.max = delta * max_int
-         self.quantizer_parameters = {}
+
 
      def initialize_quantization(self,
                                  tensor_shape: TensorShape,
                                  name: str,
-                                 layer: qi.KerasQuantizationWrapper) -> Dict[str, tf.Variable]:
+                                 layer: qi.KerasQuantizationWrapper):
          """
-         Add min and max variables to layer.
-         Args:
-             tensor_shape: Tensor shape the quantizer quantize.
-             name: Prefix of variables names.
-             layer: Layer to add the variables to. The variables are saved
-                 in the layer's scope.
+         Add quantizer parameters to the quantizer parameters dictionary
 
-         Returns:
-             Dictionary of new variables.
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
          """
          ptq_threshold_tensor = layer.add_weight(
              name + THRESHOLD_TENSOR,
@@ -115,9 +113,9 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
          fq_max.assign(self.max)
 
          # save the quantizer added parameters for later calculations
-         self.quantizer_parameters = {THRESHOLD_TENSOR: ptq_threshold_tensor,
-                                      FQ_MIN: fq_min, FQ_MAX: fq_max}
-         return self.quantizer_parameters
+         self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
 
      def __call__(self,
                   inputs: tf.Tensor,
@@ -134,8 +132,8 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
              The quantized tensor.
          """
 
-         _min = self.quantizer_parameters[FQ_MIN]
-         _max = self.quantizer_parameters[FQ_MAX]
+         _min = self.get_quantizer_variable(FQ_MIN)
+         _max = self.get_quantizer_variable(FQ_MAX)
          if self.channel_axis:
              if self.perm_vec:
                  inputs = tf.transpose(inputs, perm=self.perm_vec)
@@ -157,7 +155,7 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
              BaseKerasInferableQuantizer object.
          """
          if self.power_of_two:
-             pot_threshold = 2 ** np.ceil(np.log2(self.quantizer_parameters[THRESHOLD_TENSOR]))
+             pot_threshold = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR)))
              return WeightsPOTInferableQuantizer(num_bits=self.num_bits,
                                                  threshold=list(pot_threshold.flatten()),
                                                  per_channel=self.per_channel,
@@ -165,8 +163,7 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
                                                  input_rank=len(self.threshold_shape))
          else:
              return WeightsSymmetricInferableQuantizer(num_bits=self.num_bits,
-                                                       threshold=list(self.quantizer_parameters[
-                                                           THRESHOLD_TENSOR].numpy().flatten()),
+                                                       threshold=list(self.get_quantizer_variable(THRESHOLD_TENSOR).numpy().flatten()),
                                                        per_channel=self.per_channel,
                                                        channel_axis=self.channel_axis,
                                                        input_rank=len(self.threshold_shape))
@@ -203,22 +200,18 @@ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
          max_int = (2 ** (self.num_bits - int(self.signed))) - 1
          self.min = delta * min_int
          self.max = delta * max_int
-         self.quantizer_parameters = {}
 
      def initialize_quantization(self,
                                  tensor_shape: TensorShape,
                                  name: str,
-                                 layer: qi.KerasQuantizationWrapper) -> Dict[str, tf.Variable]:
+                                 layer: qi.KerasQuantizationWrapper):
          """
-         Add min and max variables to layer.
-         Args:
-             tensor_shape: Tensor shape the quantizer quantize.
-             name: Prefix of variables names.
-             layer: Layer to add the variables to. The variables are saved
-                 in the layer's scope.
+         Add quantizer parameters to the quantizer parameters dictionary
 
-         Returns:
-             Dictionary of new variables.
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
          """
          ptq_threshold_tensor = layer.add_weight(
              name + THRESHOLD_TENSOR,
@@ -242,9 +235,10 @@ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
          fq_max.assign(self.max)
 
          # save the quantizer added parameters for later calculations
-         self.quantizer_parameters = {THRESHOLD_TENSOR: ptq_threshold_tensor,
-                                      FQ_MIN: fq_min, FQ_MAX: fq_max}
-         return self.quantizer_parameters
+         self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
+
 
      def __call__(self,
                   inputs: tf.Tensor,
@@ -259,8 +253,8 @@ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
              The quantized tensor.
          """
 
-         _min = self.quantizer_parameters[FQ_MIN]
-         _max = self.quantizer_parameters[FQ_MAX]
+         _min = self.get_quantizer_variable(FQ_MIN)
+         _max = self.get_quantizer_variable(FQ_MAX)
          q_tensor = tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
                                                                  num_bits=self.num_bits)
 
@@ -275,7 +269,7 @@ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
          """
 
          if self.power_of_two:
-             pot_threshold = 2 ** np.ceil(np.log2(self.quantizer_parameters[THRESHOLD_TENSOR]))
+             pot_threshold = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR)))
              return ActivationPOTInferableQuantizer(num_bits=self.num_bits,
                                                     # In activation quantization is per-tensor only - thus we pass
                                                     # the threshold as a list with a len of 1
@@ -285,6 +279,5 @@ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
              return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits,
                                                           # In activation quantization is per-tensor only - thus we
                                                           # pass the threshold as a list with a len of 1
-                                                          threshold=[
-                                                              self.quantizer_parameters[THRESHOLD_TENSOR].numpy()],
+                                                          threshold=[self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()],
                                                           signed=self.signed)
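
Both Keras QAT quantizers in this file delegate the actual fake quantization to TensorFlow's built-in op, which quantizes on the forward pass while letting gradients flow through unchanged inside the range (the straight-through estimator the class names refer to). A tiny standalone example of that call, with illustrative values:

import tensorflow as tf

x = tf.constant([-1.2, -0.4, 0.0, 0.3, 0.9], dtype=tf.float32)
# 8-bit uniform fake quantization over [-1, 1]; out-of-range values are clipped
# onto the grid, and the gradient behaves as identity inside the range.
q = tf.quantization.fake_quant_with_min_max_vars(x, min=-1.0, max=1.0, num_bits=8)
print(q.numpy())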
model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py
@@ -12,9 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
-
- from typing import Dict
-
  import numpy as np
  import tensorflow as tf
  from tensorflow.python.framework.tensor_shape import TensorShape
@@ -32,6 +29,7 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
      mark_quantizer
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
      BaseKerasInferableQuantizer, WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
  @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
@@ -70,22 +68,17 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
          else:
              self.perm_vec = None
 
-         self.quantizer_parameters = {}
-
      def initialize_quantization(self,
                                  tensor_shape: TensorShape,
                                  name: str,
-                                 layer: qi.KerasQuantizationWrapper) -> Dict[str, tf.Variable]:
+                                 layer: qi.KerasQuantizationWrapper):
          """
-         Add min and max variables to layer.
-         Args:
-             tensor_shape: Tensor shape the quantizer quantize.
-             name: Prefix of variables names.
-             layer: Layer to add the variables to. The variables are saved
-                 in the layer's scope.
+         Add quantizer parameters to the quantizer parameters dictionary
 
-         Returns:
-             Dictionary of new variables.
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
          """
          fq_min = layer.add_weight(
              name + FQ_MIN,
@@ -102,8 +95,9 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
          fq_max.assign(self.max)
 
          # save the quantizer added parameters for later calculations
-         self.quantizer_parameters = {FQ_MIN: fq_min, FQ_MAX: fq_max}
-         return self.quantizer_parameters
+         self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
+
 
      def __call__(self, inputs: tf.Tensor,
                   training: bool):
@@ -117,8 +111,8 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
              The quantized tensor.
          """
 
-         _min = self.quantizer_parameters[FQ_MIN]
-         _max = self.quantizer_parameters[FQ_MAX]
+         _min = self.get_quantizer_variable(FQ_MIN)
+         _max = self.get_quantizer_variable(FQ_MAX)
          _min, _max = adjust_range_to_include_zero(_min, _max, self.num_bits)
 
          if self.per_channel:
@@ -142,8 +136,8 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
          Returns:
              BaseKerasInferableQuantizer object.
          """
-         min_range, max_range = fix_range_to_include_zero(self.quantizer_parameters[FQ_MIN].numpy(),
-                                                          self.quantizer_parameters[FQ_MAX].numpy(),
+         min_range, max_range = fix_range_to_include_zero(self.get_quantizer_variable(FQ_MIN).numpy(),
+                                                          self.get_quantizer_variable(FQ_MAX).numpy(),
                                                           self.num_bits)
          return WeightsUniformInferableQuantizer(num_bits=self.num_bits,
                                                  min_range=list(min_range.flatten()),
@@ -174,22 +168,18 @@ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
          self.num_bits = quantization_config.activation_n_bits
          self.min_range = quantization_config.activation_quantization_params[C.RANGE_MIN]
          self.max_range = quantization_config.activation_quantization_params[C.RANGE_MAX]
-         self.quantizer_parameters = {}
 
      def initialize_quantization(self,
                                  tensor_shape: TensorShape,
                                  name: str,
-                                 layer: qi.KerasQuantizationWrapper) -> Dict[str, tf.Variable]:
+                                 layer: qi.KerasQuantizationWrapper):
          """
-         Add min and max variables to layer.
-         Args:
-             tensor_shape: Tensor shape the quantizer quantize.
-             name: Prefix of variables names.
-             layer: Layer to add the variables to. The variables are saved
-                 in the layer's scope.
+         Add quantizer parameters to the quantizer parameters dictionary
 
-         Returns:
-             Dictionary of new variables.
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
          """
          fq_min = layer.add_weight(
              name + FQ_MIN,
@@ -206,8 +196,9 @@ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
          fq_max.assign(self.max_range)
 
          # save the quantizer added parameters for later calculations
-         self.quantizer_parameters = {FQ_MIN: fq_min, FQ_MAX: fq_max}
-         return self.quantizer_parameters
+         self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
+
 
      def __call__(self,
                   inputs: tf.Tensor,
@@ -222,8 +213,8 @@ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
              The quantized tensor.
          """
 
-         _min = self.quantizer_parameters[FQ_MIN]
-         _max = self.quantizer_parameters[FQ_MAX]
+         _min = self.get_quantizer_variable(FQ_MIN)
+         _max = self.get_quantizer_variable(FQ_MAX)
          _min, _max = adjust_range_to_include_zero(_min, _max, self.num_bits)
          q_tensor = tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
                                                                  num_bits=self.num_bits)
@@ -237,8 +228,8 @@ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
          Returns:
              BaseKerasInferableQuantizer object.
          """
-         min_range, max_range = fix_range_to_include_zero(self.quantizer_parameters[FQ_MIN].numpy(),
-                                                          self.quantizer_parameters[FQ_MAX].numpy(),
+         min_range, max_range = fix_range_to_include_zero(self.get_quantizer_variable(FQ_MIN).numpy(),
+                                                          self.get_quantizer_variable(FQ_MAX).numpy(),
                                                           self.num_bits)
          return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
                                                     # In activation quantization is per-tensor only - thus we pass
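
The repeated `adjust_range_to_include_zero` / `fix_range_to_include_zero` calls in this file enforce that the uniform grid over [min, max] contains an exact zero, so exact-zero inputs (padding, ReLU outputs) survive quantization. Below is a sketch of one common way to do this, assuming the usual zero-point-nudging scheme; MCT's actual helpers are not shown in this diff and may handle edge cases differently.

import numpy as np

def fix_range_to_include_zero_sketch(range_min: np.ndarray,
                                     range_max: np.ndarray,
                                     n_bits: int):
    # Force min <= 0 <= max, then shift both ends so zero lands on a grid point.
    range_min = np.minimum(range_min, 0.0)
    range_max = np.maximum(range_max, 0.0)
    scale = (range_max - range_min) / (2 ** n_bits - 1)
    zero_point = np.round(-range_min / scale)   # index of zero on the grid
    adj_min = -zero_point * scale
    adj_max = adj_min + scale * (2 ** n_bits - 1)
    return adj_min, adj_max

mn, mx = fix_range_to_include_zero_sketch(np.array(-0.73), np.array(1.11), 8)
scale = (mx - mn) / 255
assert abs(round(float(-mn / scale)) - float(-mn / scale)) < 1e-6  # zero sits on the grid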