mct-nightly 2.2.0.20241203.546__py3-none-any.whl → 2.2.0.20241205.533__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/RECORD +29 -29
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_graph.py +9 -5
  5. model_compression_toolkit/core/common/graph/base_node.py +2 -3
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +32 -35
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +9 -9
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +5 -11
  9. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +12 -0
  10. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +11 -4
  11. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +4 -6
  12. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +6 -11
  13. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +6 -9
  14. model_compression_toolkit/core/keras/data_util.py +151 -18
  15. model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py +93 -86
  16. model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py +17 -0
  17. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -2
  18. model_compression_toolkit/core/keras/keras_implementation.py +23 -27
  19. model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -4
  20. model_compression_toolkit/gptq/common/gptq_training.py +58 -0
  21. model_compression_toolkit/gptq/keras/gptq_loss.py +35 -2
  22. model_compression_toolkit/gptq/keras/gptq_training.py +137 -67
  23. model_compression_toolkit/gptq/keras/graph_info.py +1 -4
  24. model_compression_toolkit/gptq/keras/quantization_facade.py +24 -11
  25. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +23 -11
  26. model_compression_toolkit/gptq/pytorch/gptq_training.py +4 -45
  27. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/LICENSE.md +0 -0
  28. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/WHEEL +0 -0
  29. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/top_level.txt +0 -0
model_compression_toolkit/gptq/keras/gptq_training.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Callable, List, Tuple, Union
+ from typing import Callable, List, Tuple, Union, Generator

  import tensorflow as tf
  from keras import Model
@@ -20,11 +20,13 @@ from packaging import version
  from tensorflow.keras.layers import Layer
  from tqdm import tqdm

- from model_compression_toolkit.core.common.hessian import HessianInfoService
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
  from model_compression_toolkit.core.common.user_info import UserInformation
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
- from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
+ from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, \
+     FixedSampleInfoDataset, FixedTFDataset, create_tf_dataloader, TFDatasetFromGenerator, \
+     IterableSampleWithConstInfoDataset
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
  from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
      get_gradual_activation_quantizer_wrapper_factory
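Note: the new data_util helpers imported above (FixedTFDataset, FixedSampleInfoDataset, IterableSampleWithConstInfoDataset, TFDatasetFromGenerator, create_tf_dataloader) are defined in model_compression_toolkit/core/keras/data_util.py, whose diff is not shown in this excerpt. As a rough orientation for the call sites below, here is a minimal, hypothetical sketch of a FixedTFDataset-style wrapper; only the attributes samples and orig_batch_size are taken from actual usage in this file, everything else is an assumption.

# Hypothetical sketch only - the real FixedTFDataset lives in core/keras/data_util.py and may differ.
from typing import Callable, Generator, List
import tensorflow as tf

class FixedTFDatasetSketch:
    """Materializes a representative data generator into a fixed, indexable list of samples."""

    def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
        self.orig_batch_size = None    # batch size produced by the original generator
        self.samples: List[list] = []  # one entry per individual sample
        for batch in data_gen_fn():
            batch = batch if isinstance(batch, (list, tuple)) else [batch]
            batch = [tf.convert_to_tensor(t) for t in batch]
            if self.orig_batch_size is None:
                self.orig_batch_size = int(batch[0].shape[0])
            # Split each batch into single samples so they can be re-batched (and shuffled) later.
            for i in range(int(batch[0].shape[0])):
                self.samples.append([t[i] for t in batch])
                if n_samples is not None and len(self.samples) >= n_samples:
                    return

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]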
@@ -83,13 +85,10 @@ class KerasGPTQTrainer(GPTQTrainer):

          """

-         def _get_total_grad_steps():
-             return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs
-
-         # This must be set before the model building (as it is required for activation holder construction),
-         # which occurs in the base constructor.
-         self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-             gptq_config, _get_total_grad_steps, KerasLinearAnnealingScheduler)
+         self.fw_soft_quantizer_regularization = SoftQuantizerRegularization
+         self.fw_linear_annealing_scheduler = KerasLinearAnnealingScheduler
+         self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+         self.fw_get_weights_for_loss_fn = get_weights_for_loss

          super().__init__(graph_float,
                           graph_quant,
@@ -99,53 +98,106 @@ class KerasGPTQTrainer(GPTQTrainer):
                           representative_data_gen_fn=representative_data_gen,
                           hessian_info_service=hessian_info_service)

-         self.loss_list = []
-         self.input_scale = 1
-
-         trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
-             self.fxp_model,
-             fw_info,
-             add_bias=gptq_config.train_bias)
-
-         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-
-         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                 self.fxp_weights_list)):
-             Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
-                             "and the number of float and quantized weights for loss calculation. "
-                             "Ensure all these elements align to proceed with GPTQ training.")
-
-         flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
-         flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights]
-         trainable_quantization_parameters = trainable_threshold
-         self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
-                                                                   flattened_bias_weights,
-                                                                   trainable_quantization_parameters)
-         self.has_params_to_train = np.sum(
-             [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
-
-         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-             Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. "
-                             "Confirm that the input scales for both models are correctly configured and aligned.") # pragma: no cover
-         else:
-             self.input_scale = self.gptq_user_info.input_scale

-         self.weights_for_average_loss = self._get_compare_points_loss_weights()
+     def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+         """
+         Computes Sample-Layer Attention score and builds a train dataloader in TensorFlow.

-         self.reg_func = get_regularization(self.gptq_config,
-                                            _get_total_grad_steps,
-                                            SoftQuantizerRegularization,
-                                            KerasLinearAnnealingScheduler)
+         Args:
+             data_gen_fn: function for representative dataset generation.

-     def _get_compare_points_loss_weights(self):
-         """ Get compare points weights for the distillation loss. """
-         if self.gptq_config.hessian_weights_config:
-             hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
-                                                      batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
-             return self.compute_hessian_based_weights(hess_dataloader)
+         Returns:
+             TensorFlow dataset yielding three outputs - samples, weights for the distillation loss,
+             and weights for regularization.
+         """
+         # Create a fixed dataset
+         fixed_dataset = FixedTFDataset(data_gen_fn)
+         orig_batch_size = fixed_dataset.orig_batch_size
+
+         # Prepare a separate loader for computing hessians over the whole dataset
+         hess_data_loader = create_tf_dataloader(
+             fixed_dataset,
+             batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size,
+             shuffle=False
+         )
+
+         # Prepare request for Hessian computation
+         request = self._build_hessian_request(
+             granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
+             data_loader=hess_data_loader,
+             n_samples=None
+         )
+         layers_hessians = self.hessian_service.fetch_hessian(request, force_compute=True)
+
+         # Compute SLA score defined as max over elements
+         layers_hessians = {
+             layer: tf.convert_to_tensor(tf.reduce_max(hess, axis=tuple(range(1, len(hess.shape))))) for layer, hess in layers_hessians.items()
+         }
+
+         # Stack hessians for comparison points
+         hessians_tensor = tf.stack([layers_hessians[layer.name] for layer in self.compare_points])
+         assert hessians_tensor.shape[0] == len(self.compare_points)
+         loss_weights = list(hessians_tensor.numpy()) # Convert to a list for compatibility
+
+         # Prepare final dataset with samples and loss weights
+         sla_train_dataset = FixedSampleInfoDataset(fixed_dataset.samples, loss_weights)
+
+         # Calculate regularization weights as mean across samples
+         reg_weights = tf.reduce_mean(hessians_tensor, axis=1)
+
+         # Define a collate function to add regularization weights to each batch
+         def collate_fn(samples_with_loss_weights):
+             return *samples_with_loss_weights, reg_weights
+
+         # Create final dataset using the new dataloader with collate_fn
+         final_dataset = create_tf_dataloader(
+             dataset=sla_train_dataset,
+             batch_size=orig_batch_size,
+             shuffle=True,
+             collate_fn=collate_fn
+         )
+
+         return final_dataset
+
+     def _prepare_train_dataloader_for_non_sla(self,
+                                               data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+         """
+         Prepares a train dataloader for non-SLA tasks.

+         Args:
+             data_gen_fn: Factory for representative dataset generator.
+
+         Returns:
+             A `tf.data.Dataset` yielding samples with loss weights and regularization weights.
+         """
+         # Step 1: Create a dataset from the generator
+         dataset = TFDatasetFromGenerator(data_gen_fn)

          num_nodes = len(self.compare_points)
-         return np.ones((num_nodes,)) / num_nodes
+
+         # Step 2: Compute loss weights
+         if self.gptq_config.hessian_weights_config:
+             hessian_dataset = create_tf_dataloader(dataset=dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
+             hessian_weights = self.compute_hessian_based_weights(hessian_dataset)
+             loss_weights = tf.convert_to_tensor(hessian_weights, dtype=tf.float32)
+         else:
+             loss_weights = tf.ones(num_nodes, dtype=tf.float32) / num_nodes
+
+         # Step 3: Create a dataset with samples and loss weights
+         augmented_dataset = IterableSampleWithConstInfoDataset(dataset.dataset, loss_weights)
+
+         # Step 4: Add constant regularization weights
+         reg_weights = tf.ones(num_nodes, dtype=tf.float32)
+
+         def collate_fn(batch):
+             samples, loss_weights = batch
+             return samples, loss_weights, reg_weights
+
+         # Step 5: Create a tf.data.Dataset with collate_fn
+         train_dataloader = create_tf_dataloader(augmented_dataset,
+                                                 batch_size=dataset.orig_batch_size,
+                                                 collate_fn=collate_fn)
+
+         return train_dataloader

      def _is_gptq_weights_trainable(self,
                                     node: common.BaseNode) -> bool:
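Both dataloader builders above delegate batching to create_tf_dataloader from core/keras/data_util.py, whose implementation is not part of this excerpt. Below is a minimal sketch of a helper with the same call shape (dataset, batch_size, shuffle, collate_fn), assuming the dataset exposes __len__/__getitem__ like the fixed datasets above; it illustrates the idea and is not the package's actual implementation.

# Hypothetical create_tf_dataloader-like helper (signature inferred from the call sites above).
import tensorflow as tf

def create_tf_dataloader_sketch(dataset, batch_size, shuffle=False, collate_fn=None):
    """Builds a batched tf.data.Dataset from an indexable dataset object (sketch only)."""
    def gen():
        for i in range(len(dataset)):
            yield dataset[i]

    # Infer the element structure (e.g. (input_tensors, loss_weights)) from the first item.
    first = dataset[0]
    spec = tf.nest.map_structure(
        lambda t: tf.TensorSpec(shape=tf.convert_to_tensor(t).shape,
                                dtype=tf.convert_to_tensor(t).dtype),
        first)
    ds = tf.data.Dataset.from_generator(gen, output_signature=spec)

    if shuffle:
        ds = ds.shuffle(buffer_size=len(dataset))
    ds = ds.batch(batch_size)
    if collate_fn is not None:
        # e.g. append the constant regularization-weights tensor to every batch, as done above
        ds = ds.map(lambda *batch: collate_fn(batch))
    return ds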
@@ -226,9 +278,13 @@ class KerasGPTQTrainer(GPTQTrainer):

          return gptq_model, gptq_user_info

-     def compute_gradients(self, in_y_float: List[tf.Tensor], input_data: List[np.ndarray],
+     def compute_gradients(self,
+                           in_y_float: List[tf.Tensor],
+                           input_data: List[np.ndarray],
                            in_optimizer_with_param: List,
-                           training=True) -> Tuple[tf.Tensor, List[tf.Tensor]]:
+                           training=True,
+                           distill_loss_weights=None,
+                           reg_weights=None) -> Tuple[tf.Tensor, List[tf.Tensor]]:
          """
          Get outputs from both teacher and student networks. Compute the observed error,
          and use it to compute the gradients and apply them to the student weights.
@@ -253,9 +309,9 @@ class KerasGPTQTrainer(GPTQTrainer):
                                                 self.flp_weights_list,
                                                 self.compare_points_mean,
                                                 self.compare_points_std,
-                                                self.weights_for_average_loss)
+                                                distill_loss_weights)

-             reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+             reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor, reg_weights)

              loss_value += reg_value

@@ -279,14 +335,19 @@ class KerasGPTQTrainer(GPTQTrainer):
          # Training loop
          # ----------------------------------------------
          if self.has_params_to_train:
-             self.micro_training_loop(self.representative_data_gen_fn,
-                                      compute_gradients,
+             self.micro_training_loop(compute_gradients,
                                       self.optimizer_with_param,
                                       self.gptq_config.n_epochs,
                                       True)

      @tf.function
-     def nano_training_step(self, input_data, in_compute_gradients, in_optimizer_with_param, is_training):
+     def nano_training_step(self,
+                            input_data,
+                            in_compute_gradients,
+                            in_optimizer_with_param,
+                            is_training,
+                            distill_loss_weights,
+                            reg_weights):
          """
          This function runs part of the training step, wrapped by a tf.function for acceleration.
          Args:
@@ -303,12 +364,15 @@ class KerasGPTQTrainer(GPTQTrainer):
          # run float model
          y_float = self.float_model(input_data)
          # run quantized model and calculate loss & gradients
-         loss_value_step, grads = in_compute_gradients(y_float, input_data, in_optimizer_with_param,
-                                                       training=is_training)
+         loss_value_step, grads = in_compute_gradients(y_float,
+                                                       input_data,
+                                                       in_optimizer_with_param,
+                                                       training=is_training,
+                                                       distill_loss_weights=distill_loss_weights,
+                                                       reg_weights=reg_weights)
          return loss_value_step, grads

      def micro_training_loop(self,
-                             data_function: Callable,
                              in_compute_gradients: Callable,
                              in_optimizer_with_param: List[Tuple[tf.keras.optimizers.Optimizer, List[tf.Tensor]]],
                              n_epochs: int,
@@ -316,7 +380,6 @@ class KerasGPTQTrainer(GPTQTrainer):
          """
          This function runs a micro training loop on a given set of parameters.
          Args:
-             data_function: A callable function that give a batch of samples.
              in_compute_gradients: A callable function that computes the gradients.
              in_optimizer_with_param: A list of optimizer classes to update with the corresponding parameters.
              n_epochs: Number of update iterations of representative dataset.
@@ -327,12 +390,19 @@ class KerasGPTQTrainer(GPTQTrainer):
          """
          with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
              for _ in epochs_pbar:
-                 with tqdm(data_function(), position=1, leave=False) as data_pbar:
+                 with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
                      for data in data_pbar:
-                         input_data = [d * self.input_scale for d in data]

-                         loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
-                                                                          in_optimizer_with_param, is_training)
+                         input_data, distill_loss_weights, reg_weight = data
+
+                         input_data = [d * self.input_scale for d in input_data]
+
+                         loss_value_step, grads = self.nano_training_step(input_data,
+                                                                          in_compute_gradients,
+                                                                          in_optimizer_with_param,
+                                                                          is_training,
+                                                                          distill_loss_weights,
+                                                                          reg_weight)
                          # Run one step of gradient descent by updating
                          # the value of the variables to minimize the loss.
                          for i, (o, p) in enumerate(in_optimizer_with_param):
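With the new dataloaders, each element drawn from self.train_dataloader above is a 3-tuple rather than a raw input batch. A small runnable illustration of the structure the loop unpacks (the dummy dataset and shapes are illustrative only):

# Dummy stand-in mimicking the (input_data, distill_loss_weights, reg_weights) batch structure.
import tensorflow as tf

num_layers, batch = 3, 2
dummy_loader = tf.data.Dataset.from_tensors(
    ((tf.zeros((batch, 8)),),          # input_data: tuple/list of model input tensors
     tf.ones((batch, num_layers)),     # distillation-loss weights per compare point (layout illustrative)
     tf.ones((num_layers,))))          # regularization weights per compare point
for input_data, distill_loss_weights, reg_weights in dummy_loader:
    print([t.shape for t in input_data], distill_loss_weights.shape, reg_weights.shape)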
model_compression_toolkit/gptq/keras/graph_info.py
@@ -16,7 +16,6 @@
  import tensorflow as tf
  from typing import Tuple, List
  from model_compression_toolkit.core.keras.constants import USE_BIAS
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from tensorflow.keras.models import Model
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
@@ -26,7 +25,6 @@ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_qu


  def get_gptq_trainable_parameters(fxp_model: Model,
-                                   fw_info: FrameworkInfo,
                                    add_bias: bool = False) -> (
          List[tf.Variable], List[tf.Variable], List[tf.Variable]):
      """
@@ -34,7 +32,6 @@ def get_gptq_trainable_parameters(fxp_model: Model,

      Args:
          fxp_model: Model to get its trainable parameters.
-         fw_info: Framework information needed for keras kernel ops list.
          add_bias: Whether to include biases of the model (if there are) or not.

      Returns:
@@ -60,7 +57,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
              trainable_threshold.extend(quantizer_trainable_threshold)

              if add_bias:
-                 kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
+                 kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
                  use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
                             and layer.layer.get_config().get(USE_BIAS)
                  if use_bias is not None and use_bias and layer.layer.bias is not None:
model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -19,7 +19,7 @@ from packaging import version

  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
-     LR_BIAS_DEFAULT, GPTQ_MOMENTUM
+     LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
  from model_compression_toolkit.verify_packages import FOUND_TF
@@ -42,7 +42,7 @@ if FOUND_TF:
      from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
      from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
      from tensorflow.keras.models import Model
-     from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
+     from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
      from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
      from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
      from model_compression_toolkit import get_target_platform_capabilities
@@ -61,11 +61,12 @@ if FOUND_TF:
      def get_keras_gptq_config(n_epochs: int,
                                optimizer: OptimizerV2 = None,
                                optimizer_rest: OptimizerV2 = None,
-                               loss: Callable = GPTQMultipleTensorsLoss(),
+                               loss: Callable = None,
                                log_function: Callable = None,
                                use_hessian_based_weights: bool = True,
-                               regularization_factor: float = REG_DEFAULT,
+                               regularization_factor: float = None,
                                hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                               use_hessian_sample_attention: bool = False,
                                gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False) -> GradientPTQConfig:
          """
          Create a GradientPTQConfig instance for Keras models.
@@ -79,6 +80,7 @@ if FOUND_TF:
              use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
              regularization_factor (float): A floating point number that defines the regularization factor.
              hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+             use_hessian_sample_attention (bool): whether to use Sample-Layer Attention score for weighted loss.
              gradual_activation_quantization (bool, GradualActivationQuantizationConfig): If False, GradualActivationQuantization is disabled. If True, GradualActivationQuantization is enabled with the default settings. GradualActivationQuantizationConfig object can be passed to use non-default settings.

          returns:
@@ -105,9 +107,25 @@ if FOUND_TF:
          """
          optimizer = optimizer or tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT)
          optimizer_rest = optimizer_rest or tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT)
+         bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)

-         bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
-                                                  momentum=GPTQ_MOMENTUM)
+         if regularization_factor is None:
+             regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
+
+         loss = loss or GPTQMultipleTensorsLoss()
+         hessian_weights_config = None
+         if use_hessian_sample_attention:
+             if not use_hessian_based_weights: # pragma: no cover
+                 raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
+
+             hessian_weights_config = GPTQHessianScoresConfig(per_sample=True,
+                                                              hessians_num_samples=None,
+                                                              hessian_batch_size=hessian_batch_size)
+             loss = loss or sample_layer_attention_loss
+         elif use_hessian_based_weights:
+             hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
+                                                              hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
+                                                              hessian_batch_size=hessian_batch_size)

          if isinstance(gradual_activation_quantization, bool):
              gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
@@ -117,11 +135,6 @@ if FOUND_TF:
              raise TypeError(f'gradual_activation_quantization argument should be bool or '
                              f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')

-         hessian_weights_config = None
-         if use_hessian_based_weights:
-             hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
-                                                              hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
-                                                              hessian_batch_size=hessian_batch_size)
          return GradientPTQConfig(n_epochs=n_epochs,
                                   optimizer=optimizer,
                                   optimizer_rest=optimizer_rest,
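Putting the facade changes together, enabling the new option might look roughly as follows; the sketch assumes the existing public mct.gptq entry points and is illustrative rather than copied from the package documentation.

# Illustrative usage of the new use_hessian_sample_attention flag (not an official example).
import numpy as np
import model_compression_toolkit as mct

def representative_data_gen():
    # Replace with batches drawn from a real calibration dataset.
    for _ in range(10):
        yield [np.random.randn(8, 224, 224, 3).astype(np.float32)]

# With use_hessian_sample_attention=True, the config selects per-sample Hessian scores
# (GPTQHessianScoresConfig(per_sample=True, ...)) and the SLA regularization default.
gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5,
                                             use_hessian_sample_attention=True)

# quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
#     float_model, representative_data_gen, gptq_config=gptq_config)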
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -40,30 +40,42 @@ class SoftQuantizerRegularization:
          self.count_iter = tf.Variable(0.)


-     def __call__(self, model: Model, entropy_reg: float):
+     def __call__(self, model: Model, entropy_reg: float, layer_weights: tf.Tensor):
          """
          Returns the soft quantizer regularization value for SoftRounding.

          Args:
              model: A model to be quantized with SoftRounding.
              entropy_reg: Entropy value to scale the quantizer regularization.
+             layer_weights: a vector of layer weights.

          Returns: Regularization value.
          """
-         soft_reg_aux: List[tf.Tensor] = []
+         layers = [l for l in model.layers if isinstance(l, KerasTrainableQuantizationWrapper)]
+
+         if layer_weights.shape[0] != len(layers):
+             raise ValueError(f'Expected weights.shape[0] to be {len(layers)}, '
+                              f'received shape {layer_weights.shape}.') # pragma: no cover
+
          b = self.beta_scheduler(self.count_iter.value())
-         for layer in model.layers:
-             if isinstance(layer, KerasTrainableQuantizationWrapper):
-                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                       fw_info=DEFAULT_KERAS_INFO)

-                 st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
-                 soft_reg_aux.append(tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b)))
+         max_w = tf.reduce_max(layer_weights)
+
+         # Initialize reg to zero
+         reg = tf.constant(0.0, dtype=tf.float32)
+
+         # Compute the regularization term without concatenating
+         for i, layer in enumerate(layers):
+             kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                   fw_info=DEFAULT_KERAS_INFO)
+
+             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()

-         reg = 0
+             soft_loss = tf.reduce_sum(1 - tf.pow(tf.math.abs(st - 0.5) * 2, b))
+             reg += layer_weights[i] * soft_loss

-         for sq in soft_reg_aux:
-             reg += sq
+         # Normalize reg by max_w
+         reg = reg / max_w

          self.count_iter.assign_add(1.0)

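In effect the regularizer now computes reg = (sum_i w_i * sum(1 - |2*(st - 0.5)|^beta)) / max_i w_i over the wrapped layers instead of an unweighted sum. A standalone sketch of that computation on plain tensors (names are illustrative, not taken from the package):

# Standalone sketch of the weighted soft-rounding regularization term shown above.
import tensorflow as tf

def weighted_soft_rounding_reg(soft_targets_per_layer, layer_weights, beta):
    """soft_targets_per_layer: list of tensors holding per-weight soft targets in [0, 1]."""
    reg = tf.constant(0.0, dtype=tf.float32)
    for w, st in zip(layer_weights, soft_targets_per_layer):
        # Per-layer term: sum(1 - |2*(st - 0.5)|^beta), which pushes soft targets toward 0 or 1.
        soft_loss = tf.reduce_sum(1 - tf.pow(tf.abs(st - 0.5) * 2, beta))
        reg += w * soft_loss
    return reg / tf.reduce_max(layer_weights)

# Uniform weights reproduce the previous unweighted behavior (max(layer_weights) == 1).
st1 = tf.random.uniform((4, 4))
st2 = tf.random.uniform((8,))
print(weighted_soft_rounding_reg([st1, st2], tf.constant([1.0, 1.0]), beta=20.0))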
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -21,9 +21,6 @@ from torch.nn import Module
  from torch.utils.data import DataLoader
  from tqdm import tqdm

- from model_compression_toolkit.gptq.common.gradual_activation_quantization import get_gradual_activation_quantizer_wrapper_factory
- from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
-
  from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -41,7 +38,6 @@ from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable
  from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder

  from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
- from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
  from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import PytorchLinearAnnealingScheduler
  from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import SoftQuantizerRegularization as PytorchSoftQuantizerRegularization

@@ -76,13 +72,10 @@ class PytorchGPTQTrainer(GPTQTrainer):
              representative_data_gen: Dataset to use for inputs of the models.
              hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
          """
-         def _get_total_grad_steps():
-             # TODO get it from the dataset
-             return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs
-
-         # must be set prior to model building in the base class constructor
-         self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-             gptq_config, _get_total_grad_steps, PytorchLinearAnnealingScheduler)
+         self.fw_soft_quantizer_regularization = PytorchSoftQuantizerRegularization
+         self.fw_linear_annealing_scheduler = PytorchLinearAnnealingScheduler
+         self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+         self.fw_get_weights_for_loss_fn = get_weights_for_loss

          super().__init__(graph_float,
                           graph_quant,
@@ -92,40 +85,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                           representative_data_gen_fn=representative_data_gen,
                           hessian_info_service=hessian_info_service)

-         self.loss_list = []
-         self.input_scale = 1
-         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-             Logger.critical("Input scale mismatch between float and GPTQ networks. "
-                             "Ensure both networks have matching input scales.") # pragma: no cover
-         else:
-             self.input_scale = self.gptq_user_info.input_scale
-
-         trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
-             self.fxp_model,
-             add_bias=self.gptq_config.train_bias)
-
-         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                 self.fxp_weights_list)):
-             Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, "
-                             "and float vs. quantized weights for loss calculation do not match. "
-                             "Verify consistency across these parameters for successful GPTQ training.")
-
-         self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
-                                                                   trainable_bias,
-                                                                   trainable_threshold)
-         hessian_cfg = self.gptq_config.hessian_weights_config
-
-         self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
-         if self.use_sample_layer_attention:
-             # normalization is currently not supported, make sure the config reflects it.
-             if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
-                 raise NotImplementedError()
-             self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
-         else:
-             self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)
-
-         self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps, PytorchSoftQuantizerRegularization, PytorchLinearAnnealingScheduler)

      def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
          """