mct-nightly 1.8.0.27022023.post430__py3-none-any.whl → 1.8.0.27032023.post403__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/METADATA +7 -7
  2. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/RECORD +65 -59
  3. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +9 -15
  5. model_compression_toolkit/core/common/logger.py +10 -2
  6. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  7. model_compression_toolkit/core/keras/quantization_facade.py +1 -1
  8. model_compression_toolkit/core/pytorch/constants.py +4 -0
  9. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +4 -10
  10. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  11. model_compression_toolkit/exporter/__init__.py +5 -0
  12. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  13. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
  14. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  15. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -39
  16. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +39 -24
  17. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +50 -42
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -36
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +24 -5
  20. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  21. model_compression_toolkit/gptq/__init__.py +6 -0
  22. model_compression_toolkit/gptq/common/gptq_config.py +60 -106
  23. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  24. model_compression_toolkit/gptq/common/gptq_training.py +28 -38
  25. model_compression_toolkit/gptq/keras/gptq_training.py +10 -28
  26. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  27. model_compression_toolkit/gptq/keras/quantization_facade.py +6 -12
  28. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +0 -1
  29. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  30. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  31. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  32. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +22 -128
  33. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +11 -41
  34. model_compression_toolkit/gptq/pytorch/gptq_training.py +12 -4
  35. model_compression_toolkit/gptq/pytorch/graph_info.py +9 -6
  36. model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -22
  37. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +3 -1
  38. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +0 -20
  39. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -1
  40. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  41. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  42. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py +14 -0
  43. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  44. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +236 -0
  45. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  46. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +9 -31
  47. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +30 -37
  48. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +27 -36
  49. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +21 -21
  50. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +25 -26
  51. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  52. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
  53. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +12 -0
  54. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  55. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  56. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +12 -0
  57. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  58. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  59. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +53 -2
  60. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +2 -1
  61. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +22 -4
  62. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +24 -3
  63. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  64. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/LICENSE.md +0 -0
  65. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/top_level.txt +0 -0
  66. /model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +0 -0
model_compression_toolkit/gptq/common/gptq_config.py
@@ -16,9 +16,7 @@ from enum import Enum
 from typing import Callable, Any, Dict
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
 from model_compression_toolkit.core import common
-from model_compression_toolkit.gptq.common.gptq_constants import N_BATCHES_STR, QUANT_PARAM_LEARNING_STR, N_EPOCHS_STR, \
-    MAX_LSB_STR
-from model_compression_toolkit.gptq.common.gptq_quantizer_config import GPTQQuantizerConfig, SoftQuantizerConfig
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR, REG_DEFAULT


 class RoundingType(Enum):
@@ -31,30 +29,53 @@ class RoundingType(Enum):
     SoftQuantizer = 1


+class GPTQHessianWeightsConfig:
+    """
+    Configuration to use for computing the Hessian-based weights for GPTQ loss metric.
+    """
+
+    def __init__(self,
+                 hessians_num_samples: int = 16,
+                 norm_weights: bool = True,
+                 log_norm: bool = True,
+                 scale_log_norm: bool = False,
+                 hessians_n_iter: int = 50):
+
+        """
+        Initialize a GPTQHessianWeightsConfig.
+        Args:
+            hessians_num_samples (int): Number of samples to use for computing the Hessian-based weights.
+            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            log_norm (bool): Whether to use log normalization to the GPTQ Hessian-based weights.
+            scale_log_norm (bool): Whether to scale the final vector of the Hessian weights.
+            hessians_n_iter (int): Number of random iterations to run Hessian approximation for GPTQ weights.
+        """
+
+        self.hessians_num_samples = hessians_num_samples
+        self.norm_weights = norm_weights
+        self.log_norm = log_norm
+        self.scale_log_norm = scale_log_norm
+        self.hessians_n_iter = hessians_n_iter
+
+
 class GradientPTQConfig:
     """
     Configuration to use for quantization with GradientPTQ (experimental).
     """

-    def __init__(self,
-                 n_iter: int,
+    def __init__(self, n_iter: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-                 quantization_parameters_learning: bool = False,
                  rounding_type: RoundingType = RoundingType.SoftQuantizer,
-                 lsb_change_per_bit_width: dict = DefaultDict({}, lambda: 1),
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-                 log_norm: bool = True,
-                 weights_n_iter: int = 50,
-                 quantizer_config: GPTQQuantizerConfig = SoftQuantizerConfig()):
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfig.

@@ -67,18 +88,13 @@ class GradientPTQConfig:
                 accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-            quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
             rounding_type (RoundingType): An enum that defines the rounding type.
-            lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
-            optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
-            log_norm (bool): Whether to use log normalization to the GPTQ Jacobian-based weights.
-            weights_n_iter (int): Number of random iterations to run Jacobian approximation for GPTQ weights.
-            quantizer_config (GPTQQuantizerConfig): A class that contains the quantizer specific config.
+            optimizer_bias (Any): Optimizer to override the rest optimizer for bias.
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).

         """
         self.n_iter = n_iter
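The two configuration classes above replace the removed GPTQQuantizerConfig/SoftQuantizerConfig pair. A minimal construction sketch using only names visible in this diff; the overridden values are illustrative, and the commented defaults are the ones from the new signature:

    from model_compression_toolkit.gptq.common.gptq_config import GPTQHessianWeightsConfig

    # Defaults per the new signature: hessians_num_samples=16, norm_weights=True,
    # log_norm=True, scale_log_norm=False, hessians_n_iter=50.
    hessian_cfg = GPTQHessianWeightsConfig(hessians_num_samples=32,  # illustrative override
                                           scale_log_norm=True)      # enable the new min-max scaling path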
@@ -88,69 +104,36 @@
         self.log_function = log_function
         self.train_bias = train_bias

-        if quantization_parameters_learning and rounding_type == RoundingType.STE:
-            common.Logger.error("Quantization parameters learning is not supported with STE rounding.")
-
-        self.quantization_parameters_learning = quantization_parameters_learning
         self.rounding_type = rounding_type
-        self.lsb_change_per_bit_width = lsb_change_per_bit_width
-        self.eps = eps
-        self.use_jac_based_weights = use_jac_based_weights
-        self.num_samples_for_loss = num_samples_for_loss
-        self.norm_weights = norm_weights
+        self.use_hessian_based_weights = use_hessian_based_weights
         self.optimizer_quantization_parameter = optimizer_quantization_parameter
         self.optimizer_bias = optimizer_bias
-        self.log_norm = log_norm
-        self.weights_n_iter = weights_n_iter
+        self.regularization_factor = regularization_factor
+        self.hessian_weights_config = hessian_weights_config

-        if self._verify_quantizer_config(quantizer_config, rounding_type):
-            self.quantizer_config = quantizer_config
-        else:
-            common.Logger.error(f"Quantizer config of type {type(quantizer_config)} "
-                                f"is not suitable for rounding type {rounding_type}")
-
-
-    def _verify_quantizer_config(self, quantizer_config, rounding_type) -> bool:
-        """
-        Verifies that the given quantizer config matches the given rounding type.
-
-        Args:
-            quantizer_config: A quantizer config.
-            rounding_type: A RoundingType.
-
-        Returns: True if the quantizer config matches the rounding type, False otherwise.
-
-        """
-        if rounding_type == RoundingType.SoftQuantizer:
-            return type(quantizer_config) == SoftQuantizerConfig
-
-        # Here, we compare type() and not isinstance to exclude instance equality because of inheritance
-        return type(quantizer_config) == GPTQQuantizerConfig
+        # Since the default quantizer is soft quantizer, we initialize the gptq_quantizer_params_override dictionary
+        # with its extended params
+        self.gptq_quantizer_params_override = {QUANT_PARAM_LEARNING_STR: False} \
+            if gptq_quantizer_params_override is None else gptq_quantizer_params_override


 class GradientPTQConfigV2(GradientPTQConfig):
     """
     Configuration to use for quantization with GradientPTQV2 (experimental).
     """
-    def __init__(self,
-                 n_epochs: int,
+    def __init__(self, n_epochs: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-                 quantization_parameters_learning: bool = False,
                  rounding_type: RoundingType = RoundingType.SoftQuantizer,
-                 lsb_change_per_bit_width: dict = DefaultDict({}, lambda: 1),
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-                 log_norm: bool = True,
-                 weights_n_iter: int = 50,
-                 quantizer_config: GPTQQuantizerConfig = SoftQuantizerConfig()):
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfigV2.

@@ -163,18 +146,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
                 accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-            quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
             rounding_type (RoundingType): An enum that defines the rounding type.
-            lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
             optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
-            log_norm (bool): Whether to use log normalization to the GPTQ Jacobian-based weights.
-            weights_n_iter (int): Number of random iterations to run Jacobian approximation for GPTQ weights.
-            quantizer_config (Any): A class that contains the quantizer specific config.
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).

         """

@@ -184,18 +162,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
                          loss=loss,
                          log_function=log_function,
                          train_bias=train_bias,
-                         quantization_parameters_learning=quantization_parameters_learning,
                          rounding_type=rounding_type,
-                         lsb_change_per_bit_width=lsb_change_per_bit_width,
-                         eps=eps,
-                         use_jac_based_weights=use_jac_based_weights,
-                         num_samples_for_loss=num_samples_for_loss,
-                         norm_weights=norm_weights,
+                         use_hessian_based_weights=use_hessian_based_weights,
                          optimizer_quantization_parameter=optimizer_quantization_parameter,
                          optimizer_bias=optimizer_bias,
-                         log_norm=log_norm,
-                         weights_n_iter=weights_n_iter,
-                         quantizer_config=quantizer_config)
+                         regularization_factor=regularization_factor,
+                         hessian_weights_config=hessian_weights_config,
+                         gptq_quantizer_params_override=gptq_quantizer_params_override)
         self.n_epochs = n_epochs

@@ -212,22 +185,3 @@ class GradientPTQConfigV2(GradientPTQConfig):
         v1_params = config_v1.__dict__
         v1_params = {k: v for k, v in v1_params.items() if k != 'n_iter'}
         return cls(n_epochs, **v1_params)
-
-    def get_extended_quantizer_parametes(self) -> Dict[str, Any]:
-        """
-        Return a dictionary with a mapping to necessary additional parameters for initializing the GPTQ quantizer.
-
-        Returns: A dictionary with parameters for initializing a quantizer.
-
-        """
-
-        if self.rounding_type == RoundingType.SoftQuantizer:
-            return {N_BATCHES_STR: self.quantizer_config.n_batches,
-                    QUANT_PARAM_LEARNING_STR: self.quantization_parameters_learning,
-                    N_EPOCHS_STR: self.n_epochs}
-        elif self.rounding_type == RoundingType.STE:
-            return {MAX_LSB_STR: self.lsb_change_per_bit_width}
-
-        return {}
-
-
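For callers migrating off the removed keyword arguments, the correspondence this diff implies is summarized below; the STE entry is inferred from the deleted get_extended_quantizer_parametes branch and should be treated as an assumption:

    # Old GradientPTQConfig kwarg       -> new location (per this diff)
    # use_jac_based_weights             -> use_hessian_based_weights
    # num_samples_for_loss              -> hessian_weights_config.hessians_num_samples
    # norm_weights                      -> hessian_weights_config.norm_weights
    # log_norm                          -> hessian_weights_config.log_norm
    # weights_n_iter                    -> hessian_weights_config.hessians_n_iter
    # quantization_parameters_learning  -> gptq_quantizer_params_override['quantization_parameter_learning']
    # lsb_change_per_bit_width (STE)    -> gptq_quantizer_params_override['max_lsbs_change_map'] (assumed)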
model_compression_toolkit/gptq/common/gptq_constants.py
@@ -2,7 +2,6 @@
 AUXVAR = 'auxvar_tensor'
 ITERVAR = 'iteration_variable'
 SCALE_TENSOR = "scale_ptq_tensor"
-GPTQ_ITER = "gptq_iter"
 AUXSHIFT = 'shift'
 WEIGHTS_QUANTIZATION_PARAMS = 'weights_quantization_params'
 PTQ_MIN_RANGE = "min_range"
@@ -11,22 +10,16 @@ PTQ_THRESHOLD = "ptq_threshold"
 SCALE_PTQ = "scale"

 # Default quantizer values
-N_EPOCHS = 10000
 N_CYCLES = 4
 MIM_TEMP = 0.5
 MAX_TEMP = 1.0
 REG_DEFAULT = 0.01
-MAX_ITERATIONS_DEFAULT = 10000
 MAX_LSB_CHANGE = 1

 # Soft rounding arguments values
 SOFT_ROUNDING_GAMMA = -0.1
 SOFT_ROUNDING_ZETA = 1.1
-SOFT_ROUNDING_BETA = 2 / 3

 # GPTQ config constant
-REGULARIZATION_VALUES = "regularization_values"
-N_BATCHES_STR = 'n_batches'
 QUANT_PARAM_LEARNING_STR = 'quantization_parameter_learning'
-N_EPOCHS_STR = 'n_epochs'
 MAX_LSB_STR = 'max_lsbs_change_map'
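For orientation: the surviving SOFT_ROUNDING_GAMMA/SOFT_ROUNDING_ZETA pair are the stretch parameters of the rectified sigmoid used in AdaRound-style soft rounding. A standalone numpy sketch of that function only; the package's actual implementation lives in the soft_rounding modules added in this release and is not shown in this diff:

    import numpy as np

    SOFT_ROUNDING_GAMMA = -0.1
    SOFT_ROUNDING_ZETA = 1.1

    def rectified_sigmoid(v: np.ndarray) -> np.ndarray:
        # Stretch a sigmoid to (gamma, zeta), then clip to [0, 1]; the result serves
        # as the learnable rounding offset added to the floored weight.
        s = 1.0 / (1.0 + np.exp(-v))
        return np.clip(s * (SOFT_ROUNDING_ZETA - SOFT_ROUNDING_GAMMA) + SOFT_ROUNDING_GAMMA, 0.0, 1.0)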
model_compression_toolkit/gptq/common/gptq_training.py
@@ -16,10 +16,11 @@ import copy
 from abc import ABC, abstractmethod
 import numpy as np
 from typing import Callable, List, Any
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph, Logger, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
 from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode

@@ -34,8 +35,7 @@ class GPTQTrainer(ABC):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
-                 fw_info: FrameworkInfo,
-                 representative_data_gen: Callable):
+                 fw_info: FrameworkInfo):
         """
         Build two models from a graph: A teacher network (float model) and a student network (quantized model).
         Use the dataset generator to pass images through the teacher and student networks to get intermediate
@@ -48,7 +48,6 @@ class GPTQTrainer(ABC):
            gptq_config: GradientPTQConfig with parameters about the tuning process.
            fw_impl: Framework implementation
            fw_info: Framework information
-           representative_data_gen: Dataset to use for inputs of the models.
        """
        self.graph_float = copy.deepcopy(graph_float)
        self.graph_quant = copy.deepcopy(graph_quant)
@@ -66,10 +65,6 @@ class GPTQTrainer(ABC):
                                                                append2output=self.compare_points,
                                                                fw_info=self.fw_info)

-        if self.gptq_config.rounding_type == RoundingType.SoftQuantizer:
-            # dry run on the representative dataset to count number of batches
-            self.count_num_batches_for_training(representative_data_gen)
-
        self.fxp_model, self.gptq_user_info = self.build_gptq_model()

    def get_optimizer_with_param(
@@ -88,8 +83,10 @@

        w2train = [*flattened_trainable_weights]

+        quant_params_learning = self.gptq_config.gptq_quantizer_params_override.get(QUANT_PARAM_LEARNING_STR, False)
+
        optimizer_with_param = [(self.gptq_config.optimizer, w2train)]
-        if self.gptq_config.train_bias or self.gptq_config.quantization_parameters_learning:
+        if self.gptq_config.train_bias or quant_params_learning:
            w2train_res = []
            if self.gptq_config.train_bias:
                if self.gptq_config.optimizer_bias is not None:
@@ -99,7 +96,7 @@ class GPTQTrainer(ABC):
                    if self.gptq_config.optimizer_rest is None:
                        Logger.error(  # pragma: no cover
                            "To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
-            if self.gptq_config.quantization_parameters_learning:
+            if quant_params_learning:
                if self.gptq_config.optimizer_quantization_parameter is not None:  # Ability to override optimizer
                    optimizer_with_param.append((self.gptq_config.optimizer_quantization_parameter,
                                                 trainable_quantization_parameters))
@@ -107,25 +104,32 @@ class GPTQTrainer(ABC):
                    w2train_res.extend(trainable_quantization_parameters)
                    if self.gptq_config.optimizer_rest is None:
                        Logger.error(  # pragma: no cover
-                            "To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
-            optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
+                            "To enable quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
+            if len(w2train_res) > 0:
+                # Either bias or quantization parameters are trainable but did not provide a specific optimizer,
+                # so we should use optimizer_rest to train them
+                if self.gptq_config.optimizer_rest is None:
+                    Logger.error(  # pragma: no cover
+                        "To enable bias or quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
+                optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))

        return optimizer_with_param


-    def compute_jacobian_based_weights(self,
-                                       representative_data_gen: Callable) -> np.ndarray:
+    def compute_hessian_based_weights(self,
+                                      representative_data_gen: Callable) -> np.ndarray:
        """
-        Computes the jacobian-based weights using the framework's model_grad method per batch of images.
+        Computes the Hessian-based weights using the framework's model_grad method per batch of images.

        Args:
-            representative_data_gen: Dataset used for inference to compute the jacobian-based weights.
+            representative_data_gen: Dataset used for inference to compute the Hessian-based weights.

        Returns: A vector of weights, one for each compare point,
        to be used for the loss metric weighted average computation when running GPTQ training.
        """
-        if self.gptq_config.use_jac_based_weights:
-            images = self._generate_images_batch(representative_data_gen, self.gptq_config.num_samples_for_loss)
+        if self.gptq_config.use_hessian_based_weights:
+            images = self._generate_images_batch(representative_data_gen,
+                                                 self.gptq_config.hessian_weights_config.hessians_num_samples)

            model_output_replacement = self._get_model_output_replacement()

@@ -143,17 +147,18 @@ class GPTQTrainer(ABC):
                    output_list=model_output_replacement,
                    all_outputs_indices=[],
                    alpha=0,
-                    norm_weights=self.gptq_config.norm_weights,
-                    n_iter=self.gptq_config.weights_n_iter)
+                    norm_weights=self.gptq_config.hessian_weights_config.norm_weights,
+                    n_iter=self.gptq_config.hessian_weights_config.hessians_n_iter)
                points_apprx_jacobians_weights.append(image_ip_gradients)
-            if self.gptq_config.log_norm:
+            if self.gptq_config.hessian_weights_config.log_norm:
                mean_jacobian_weights = np.mean(points_apprx_jacobians_weights, axis=0)
                mean_jacobian_weights = np.where(mean_jacobian_weights != 0, mean_jacobian_weights,
                                                 np.partition(mean_jacobian_weights, 1)[1])
                log_weights = np.log10(mean_jacobian_weights)

-                # To add scaling to the normalized weights replace return statement with the following line:
-                # return log_weights - np.min(log_weights) / (np.max(log_weights) - np.min(log_weights))
+                if self.gptq_config.hessian_weights_config.scale_log_norm:
+                    return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
+
                return log_weights - np.min(log_weights)
            else:
                return np.mean(points_apprx_jacobians_weights, axis=0)
249
254
  replacement_outputs.append(prev_node)
250
255
  return replacement_outputs
251
256
 
252
- def count_num_batches_for_training(self, representative_data_gen):
253
- """
254
- Runs a "dry-run" of the representative dataset to count the number of batches for each training epoch.
255
-
256
- Args:
257
- representative_data_gen: A callable method to retrieve images from Dataset.
258
-
259
- Returns: The number of batches for each training epoch.
260
-
261
- """
262
- num_batches = 0
263
- for _ in representative_data_gen():
264
- num_batches += 1
265
- self.gptq_config.quantizer_config.set_num_batches(num_batches)
266
-
267
257
 
268
258
  def gptq_training(graph_float: Graph,
269
259
  graph_quant: Graph,
model_compression_toolkit/gptq/keras/gptq_training.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from functools import partial
 from typing import Callable, List, Tuple, Union

 import tensorflow as tf
@@ -23,7 +22,6 @@ from tqdm import tqdm
 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
-from model_compression_toolkit.gptq.common.gptq_constants import REGULARIZATION_VALUES
 from packaging import version

 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
@@ -37,15 +35,15 @@ else:

 from model_compression_toolkit.core import common
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2, RoundingType
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, \
-    get_soft_rounding_reg, get_gptq_trainable_parameters
+from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
+from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 import numpy as np
 import copy
-from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS, KERNEL
+from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
 from model_compression_toolkit import quantizers_infrastructure as qi


@@ -79,13 +77,12 @@ class KerasGPTQTrainer(GPTQTrainer):
                         graph_quant,
                         gptq_config,
                         fw_impl,
-                         fw_info,
-                         representative_data_gen)
+                         fw_info)

        self.loss_list = []
        self.input_scale = 1

-        trainable_weights, bias_weights, trainable_threshold, temperature_weights = get_gptq_trainable_parameters(
+        trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
            self.fxp_model,
            fw_info,
            add_bias=gptq_config.train_bias)
@@ -112,7 +109,9 @@ class KerasGPTQTrainer(GPTQTrainer):
        else:
            self.input_scale = self.gptq_user_info.input_scale

-        self.weights_for_average_loss = self.compute_jacobian_based_weights(representative_data_gen)
+        self.weights_for_average_loss = self.compute_hessian_based_weights(representative_data_gen)
+
+        self.reg_func = get_regularization(self.gptq_config, representative_data_gen)

    def _is_gptq_applicable(self,
                            node: common.BaseNode) -> bool:
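regularization_factory.py (added in this release, +45 lines per framework) is not included in this diff, so only the call contract is visible here: get_regularization returns a callable taking (model, regularization_factor) and returning a scalar penalty added to the loss. A hedged stand-in satisfying that contract:

    import tensorflow as tf

    def zero_regularization(model: tf.keras.Model, reg_factor: float) -> tf.Tensor:
        # Placeholder reg_func with the signature the training step expects; the real
        # factory presumably returns a soft-rounding penalty for RoundingType.SoftQuantizer
        # and a no-op for STE (an assumption based on the code removed below).
        return tf.constant(0.0)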
@@ -195,9 +194,7 @@ class KerasGPTQTrainer(GPTQTrainer):
                                          self.compare_points_std,
                                          self.weights_for_average_loss)

-            reg_value = self.gptq_config.quantizer_config.get_regularization_value(
-                self.fxp_model,
-                **{REGULARIZATION_VALUES: self._get_quantizer_regularization_values(self.gptq_config.rounding_type)})
+            reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)

            loss_value += reg_value

@@ -319,18 +316,3 @@ class KerasGPTQTrainer(GPTQTrainer):
                node.set_weights_by_keys(BIAS, new_bias)

        return graph
-
-    def _get_quantizer_regularization_values(self, rounding_type: RoundingType) -> List[tf.Tensor]:
-        """
-        Mapping between a rounding type to its matching regularization method.
-
-        Args:
-            rounding_type: GPTQ rounding type.
-
-        Returns: A regularization computation method.
-
-        """
-        if rounding_type == RoundingType.SoftQuantizer:
-            return get_soft_rounding_reg(self.fxp_model)
-        else:
-            return []
model_compression_toolkit/gptq/keras/graph_info.py
@@ -13,23 +13,21 @@
 # limitations under the License.
 # ==============================================================================

-
 import tensorflow as tf
 from typing import Tuple, List
-
 from model_compression_toolkit.core.keras.constants import USE_BIAS
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from tensorflow.keras.models import Model
-
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


 def get_gptq_trainable_parameters(fxp_model: Model,
                                   fw_info: FrameworkInfo,
                                   add_bias: bool = False) -> (
-        List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable]):
+        List[tf.Variable], List[tf.Variable], List[tf.Variable]):
     """
     Get trainable parameters from all layers in a model

@@ -45,16 +43,17 @@ def get_gptq_trainable_parameters(fxp_model: Model,
     trainable_weights: List[tf.Tensor] = []
     trainable_threshold: List[tf.Tensor] = []
     bias_weights: List[List[tf.Tensor]] = []
-    temperature_weights: List[tf.Tensor] = []

     for layer in fxp_model.layers:
         if isinstance(layer, KerasQuantizationWrapper):
             kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
                                                                   fw_info=DEFAULT_KERAS_INFO)

-            # collect trainable weights per layer
-            layer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_aux_variable()
-            layer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_quantization_variable()
+            # collect trainable weights per quantizer
+            quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
+            quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
+            trainable_weights.append(quantizer_trainable_weights)
+            trainable_threshold.extend(quantizer_trainable_threshold)

             if add_bias:
                 kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
@@ -62,10 +61,8 @@ def get_gptq_trainable_parameters(fxp_model: Model,
                            and layer.layer.get_config().get(USE_BIAS)
                 if use_bias is not None and use_bias:
                     bias_weights.append([layer.layer.bias])
-            trainable_weights.append(layer_trainable_weights)
-            trainable_threshold.extend(layer_trainable_threshold)

-    return trainable_weights, bias_weights, trainable_threshold, temperature_weights
+    return trainable_weights, bias_weights, trainable_threshold


 def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
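The rewritten loop above pulls variables through the trainable-infrastructure API instead of the removed get_aux_variable/get_quantization_variable pair. A small sketch of the per-quantizer contract it relies on, where quantizer stands for an entry of layer.weights_quantizers:

    from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup

    def collect_quantizer_variables(quantizer):
        # WEIGHTS: auxiliary rounding variables, trained by the main optimizer.
        # QPARAMS: quantization parameters (e.g. thresholds), trained by optimizer_rest
        # or an override (see get_optimizer_with_param above).
        aux_vars = quantizer.get_trainable_variables(VariableGroup.WEIGHTS)
        qparam_vars = quantizer.get_trainable_variables(VariableGroup.QPARAMS)
        return aux_vars, qparam_vars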
@@ -95,25 +92,3 @@ def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
             fxp_weights_list.append(_layer_fxp_weights)

     return flp_weights_list, fxp_weights_list
-
-
-# TODO: this function need to move to location that is relevant only for soft quantizer -
-#  once deciding how to handle GPTQ quantizers regularization.
-def get_soft_rounding_reg(fxp_model: Model) -> List[tf.Tensor]:
-    """
-    This function returns the soft quantizer regularization values for SoftRounding.
-
-    Args:
-        fxp_model: A model to be quantized with SoftRounding.
-
-    Returns: A list of tensors.
-    """
-
-    soft_reg_aux: List[tf.Tensor] = []
-    for layer in fxp_model.layers:
-        if isinstance(layer, KerasQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_KERAS_INFO)
-
-            soft_reg_aux.append(layer.weights_quantizers[kernel_attribute].get_regularization())
-    return soft_reg_aux
model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -85,24 +85,18 @@ if common.constants.FOUND_TF:

        Create a GradientPTQConfigV2 to run for 5 epochs:

-        >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=5)
+        >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=5)

        Other Tensorflow optimizers can be passed:

-        >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())
+        >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())

        The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.

        """
        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
-        return GradientPTQConfigV2(n_epochs,
-                                   optimizer,
-                                   optimizer_rest=optimizer_rest,
-                                   loss=loss,
-                                   log_function=log_function,
-                                   train_bias=True,
-                                   quantization_parameters_learning=True,
-                                   optimizer_bias=bias_optimizer)
+        return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)


    def keras_gradient_post_training_quantization_experimental(in_model: Model,
@@ -181,11 +175,11 @@ if common.constants.FOUND_TF:

        Create GPTQ config:

-        >>> gptq_config = mct.get_keras_gptq_config(n_epochs=1)
+        >>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)

        Pass the model with the representative dataset generator to get a quantized model:

-        >>> quantized_model, quantization_info = mct.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
+        >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)

        """
        KerasModelValidation(model=in_model,
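Since both facade docstrings now route through the mct.gptq namespace, a compact end-to-end sketch under that assumption; the model and one-batch random dataset are placeholders, and target_kpi/core_config are omitted for brevity:

    import numpy as np
    import tensorflow as tf
    import model_compression_toolkit as mct

    model = tf.keras.applications.MobileNetV2()

    def repr_datagen():
        # Representative dataset generator: yields batches shaped like the model input.
        yield [np.random.random((1, 224, 224, 3))]

    gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)
    quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization_experimental(
        model, repr_datagen, gptq_config)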