mct-nightly 2.2.0.20241203.546__py3-none-any.whl → 2.2.0.20241205.533__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/RECORD +29 -29
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_graph.py +9 -5
- model_compression_toolkit/core/common/graph/base_node.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +32 -35
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +9 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +5 -11
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +12 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +11 -4
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +4 -6
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +6 -11
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +6 -9
- model_compression_toolkit/core/keras/data_util.py +151 -18
- model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py +93 -86
- model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py +17 -0
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -2
- model_compression_toolkit/core/keras/keras_implementation.py +23 -27
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -4
- model_compression_toolkit/gptq/common/gptq_training.py +58 -0
- model_compression_toolkit/gptq/keras/gptq_loss.py +35 -2
- model_compression_toolkit/gptq/keras/gptq_training.py +137 -67
- model_compression_toolkit/gptq/keras/graph_info.py +1 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +24 -11
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +23 -11
- model_compression_toolkit/gptq/pytorch/gptq_training.py +4 -45
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/top_level.txt +0 -0
model_compression_toolkit/gptq/keras/gptq_training.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Tuple, Union, Generator
 
 import tensorflow as tf
 from keras import Model
@@ -20,11 +20,13 @@ from packaging import version
 from tensorflow.keras.layers import Layer
 from tqdm import tqdm
 
-from model_compression_toolkit.core.common.hessian import HessianInfoService
+from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity
 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
-from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
+from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, \
+    FixedSampleInfoDataset, FixedTFDataset, create_tf_dataloader, TFDatasetFromGenerator, \
+    IterableSampleWithConstInfoDataset
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
     get_gradual_activation_quantizer_wrapper_factory
@@ -83,13 +85,10 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         """
 
-
-
-
-
-        # which occurs in the base constructor.
-        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-            gptq_config, _get_total_grad_steps, KerasLinearAnnealingScheduler)
+        self.fw_soft_quantizer_regularization = SoftQuantizerRegularization
+        self.fw_linear_annealing_scheduler = KerasLinearAnnealingScheduler
+        self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+        self.fw_get_weights_for_loss_fn = get_weights_for_loss
 
         super().__init__(graph_float,
                          graph_quant,
@@ -99,53 +98,106 @@ class KerasGPTQTrainer(GPTQTrainer):
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)
 
-        self.loss_list = []
-        self.input_scale = 1
-
-        trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
-            self.fxp_model,
-            fw_info,
-            add_bias=gptq_config.train_bias)
-
-        self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-
-        if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                self.fxp_weights_list)):
-            Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
-                            "and the number of float and quantized weights for loss calculation. "
-                            "Ensure all these elements align to proceed with GPTQ training.")
-
-        flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
-        flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights]
-        trainable_quantization_parameters = trainable_threshold
-        self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
-                                                                  flattened_bias_weights,
-                                                                  trainable_quantization_parameters)
-        self.has_params_to_train = np.sum(
-            [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
-
-        if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-            Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. "
-                            "Confirm that the input scales for both models are correctly configured and aligned.")  # pragma: no cover
-        else:
-            self.input_scale = self.gptq_user_info.input_scale
 
-
+    def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+        """
+        Computes Sample-Layer Attention score and builds a train dataloader in TensorFlow.
 
-
-
-        SoftQuantizerRegularization,
-        KerasLinearAnnealingScheduler)
+        Args:
+            data_gen_fn: function for representative dataset generation.
 
-
-
-
-
-
-
+        Returns:
+            TensorFlow dataset yielding three outputs - samples, weights for the distillation loss,
+            and weights for regularization.
+        """
+        # Create a fixed dataset
+        fixed_dataset = FixedTFDataset(data_gen_fn)
+        orig_batch_size = fixed_dataset.orig_batch_size
+
+        # Prepare a separate loader for computing hessians over the whole dataset
+        hess_data_loader = create_tf_dataloader(
+            fixed_dataset,
+            batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size,
+            shuffle=False
+        )
+
+        # Prepare request for Hessian computation
+        request = self._build_hessian_request(
+            granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
+            data_loader=hess_data_loader,
+            n_samples=None
+        )
+        layers_hessians = self.hessian_service.fetch_hessian(request, force_compute=True)
+
+        # Compute SLA score defined as max over elements
+        layers_hessians = {
+            layer: tf.convert_to_tensor(tf.reduce_max(hess, axis=tuple(range(1, len(hess.shape))))) for layer, hess in layers_hessians.items()
+        }
+
+        # Stack hessians for comparison points
+        hessians_tensor = tf.stack([layers_hessians[layer.name] for layer in self.compare_points])
+        assert hessians_tensor.shape[0] == len(self.compare_points)
+        loss_weights = list(hessians_tensor.numpy())  # Convert to a list for compatibility
+
+        # Prepare final dataset with samples and loss weights
+        sla_train_dataset = FixedSampleInfoDataset(fixed_dataset.samples, loss_weights)
+
+        # Calculate regularization weights as mean across samples
+        reg_weights = tf.reduce_mean(hessians_tensor, axis=1)
+
+        # Define a collate function to add regularization weights to each batch
+        def collate_fn(samples_with_loss_weights):
+            return *samples_with_loss_weights, reg_weights
+
+        # Create final dataset using the new dataloader with collate_fn
+        final_dataset = create_tf_dataloader(
+            dataset=sla_train_dataset,
+            batch_size=orig_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn
+        )
+
+        return final_dataset
+
+    def _prepare_train_dataloader_for_non_sla(self,
+                                              data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+        """
+        Prepares a train dataloader for non-SLA tasks.
 
+        Args:
+            data_gen_fn: Factory for representative dataset generator.
+
+        Returns:
+            A `tf.data.Dataset` yielding samples with loss weights and regularization weights.
+        """
+        # Step 1: Create a dataset from the generator
+        dataset = TFDatasetFromGenerator(data_gen_fn)
         num_nodes = len(self.compare_points)
-
+
+        # Step 2: Compute loss weights
+        if self.gptq_config.hessian_weights_config:
+            hessian_dataset = create_tf_dataloader(dataset=dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
+            hessian_weights = self.compute_hessian_based_weights(hessian_dataset)
+            loss_weights = tf.convert_to_tensor(hessian_weights, dtype=tf.float32)
+        else:
+            loss_weights = tf.ones(num_nodes, dtype=tf.float32) / num_nodes
+
+        # Step 3: Create a dataset with samples and loss weights
+        augmented_dataset = IterableSampleWithConstInfoDataset(dataset.dataset, loss_weights)
+
+        # Step 4: Add constant regularization weights
+        reg_weights = tf.ones(num_nodes, dtype=tf.float32)
+
+        def collate_fn(batch):
+            samples, loss_weights = batch
+            return samples, loss_weights, reg_weights
+
+        # Step 5: Create a tf.data.Dataset with collate_fn
+        train_dataloader = create_tf_dataloader(augmented_dataset,
+                                                batch_size=dataset.orig_batch_size,
+                                                collate_fn=collate_fn)
+
+        return train_dataloader
 
     def _is_gptq_weights_trainable(self,
                                    node: common.BaseNode) -> bool:
@@ -226,9 +278,13 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         return gptq_model, gptq_user_info
 
-    def compute_gradients(self,
+    def compute_gradients(self,
+                          in_y_float: List[tf.Tensor],
+                          input_data: List[np.ndarray],
                           in_optimizer_with_param: List,
-                          training=True
+                          training=True,
+                          distill_loss_weights=None,
+                          reg_weights=None) -> Tuple[tf.Tensor, List[tf.Tensor]]:
         """
         Get outputs from both teacher and student networks. Compute the observed error,
         and use it to compute the gradients and applying them to the student weights.
@@ -253,9 +309,9 @@ class KerasGPTQTrainer(GPTQTrainer):
                                                self.flp_weights_list,
                                                self.compare_points_mean,
                                                self.compare_points_std,
-
+                                               distill_loss_weights)
 
-            reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+            reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor, reg_weights)
 
             loss_value += reg_value
 
@@ -279,14 +335,19 @@ class KerasGPTQTrainer(GPTQTrainer):
         # Training loop
         # ----------------------------------------------
         if self.has_params_to_train:
-            self.micro_training_loop(
-                compute_gradients,
+            self.micro_training_loop(compute_gradients,
                                      self.optimizer_with_param,
                                      self.gptq_config.n_epochs,
                                      True)
 
     @tf.function
-    def nano_training_step(self,
+    def nano_training_step(self,
+                           input_data,
+                           in_compute_gradients,
+                           in_optimizer_with_param,
+                           is_training,
+                           distill_loss_weights,
+                           reg_weights):
         """
         This function run part of the training step, wrapped by a tf.function for acceleration.
         Args:
@@ -303,12 +364,15 @@ class KerasGPTQTrainer(GPTQTrainer):
         # run float model
         y_float = self.float_model(input_data)
         # rung quantized model and calculate loss & gradients
-        loss_value_step, grads = in_compute_gradients(y_float,
-
+        loss_value_step, grads = in_compute_gradients(y_float,
+                                                      input_data,
+                                                      in_optimizer_with_param,
+                                                      training=is_training,
+                                                      distill_loss_weights=distill_loss_weights,
+                                                      reg_weights=reg_weights)
         return loss_value_step, grads
 
     def micro_training_loop(self,
-                            data_function: Callable,
                             in_compute_gradients: Callable,
                             in_optimizer_with_param: List[Tuple[tf.keras.optimizers.Optimizer, List[tf.Tensor]]],
                             n_epochs: int,
@@ -316,7 +380,6 @@ class KerasGPTQTrainer(GPTQTrainer):
         """
         This function run a micro training loop on given set of parameters.
         Args:
-            data_function: A callable function that give a batch of samples.
             in_compute_gradients: A callable function that compute the gradients.
             in_optimizer_with_param: A list of optimizer classes to update with the corresponding parameters.
             n_epochs: Number of update iterations of representative dataset.
@@ -327,12 +390,19 @@ class KerasGPTQTrainer(GPTQTrainer):
         """
         with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
             for _ in epochs_pbar:
-                with tqdm(
+                with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
                     for data in data_pbar:
-                        input_data = [d * self.input_scale for d in data]
 
-
-
+                        input_data, distill_loss_weights, reg_weight = data
+
+                        input_data = [d * self.input_scale for d in input_data]
+
+                        loss_value_step, grads = self.nano_training_step(input_data,
+                                                                         in_compute_gradients,
+                                                                         in_optimizer_with_param,
+                                                                         is_training,
+                                                                         distill_loss_weights,
+                                                                         reg_weight)
                         # Run one step of gradient descent by updating
                         # the value of the variables to minimize the loss.
                         for i, (o, p) in enumerate(in_optimizer_with_param):
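The three-element batches produced by the two dataloader methods above (samples, per-sample distillation-loss weights, per-layer regularization weights) are what micro_training_loop now unpacks. The snippet below is an illustrative sketch, not package code: it reproduces only the score reduction that _prepare_train_dataloader_sla applies to the fetched Hessian approximations, using plain TensorFlow with made-up layer names and shapes (the real implementation goes through MCT-internal helpers such as FixedTFDataset, FixedSampleInfoDataset and create_tf_dataloader).

    import tensorflow as tf

    def sla_scores(layers_hessians):
        """layers_hessians: dict mapping layer name -> Hessian approximation of shape (n_samples, ...)."""
        per_layer = {
            # Max over all non-batch axes, mirroring the reduce_max call in the diff above.
            name: tf.reduce_max(h, axis=tuple(range(1, len(h.shape))))
            for name, h in layers_hessians.items()
        }
        # Shape (n_layers, n_samples); in the trainer the row order follows the comparison points.
        return tf.stack(list(per_layer.values()))

    # Dummy data: two hypothetical layers, four samples each.
    hess = {'conv1': tf.random.uniform((4, 8, 8, 16)), 'dense1': tf.random.uniform((4, 32))}
    scores = sla_scores(hess)
    loss_weights = list(scores.numpy())           # per-layer, per-sample distillation weights
    reg_weights = tf.reduce_mean(scores, axis=1)  # per-layer regularization weights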
model_compression_toolkit/gptq/keras/graph_info.py
CHANGED
@@ -16,7 +16,6 @@
 import tensorflow as tf
 from typing import Tuple, List
 from model_compression_toolkit.core.keras.constants import USE_BIAS
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from tensorflow.keras.models import Model
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
@@ -26,7 +25,6 @@ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_qu
 
 
 def get_gptq_trainable_parameters(fxp_model: Model,
-                                  fw_info: FrameworkInfo,
                                   add_bias: bool = False) -> (
         List[tf.Variable], List[tf.Variable], List[tf.Variable]):
     """
@@ -34,7 +32,6 @@ def get_gptq_trainable_parameters(fxp_model: Model,
 
     Args:
         fxp_model: Model to get its trainable parameters.
-        fw_info: Framework information needed for keras kernel ops list.
         add_bias: Whether to include biases of the model (if there are) or not.
 
     Returns:
@@ -60,7 +57,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
             trainable_threshold.extend(quantizer_trainable_threshold)
 
             if add_bias:
-                kernel_ops_attrs =
+                kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
                 use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
                            and layer.layer.get_config().get(USE_BIAS)
                 if use_bias is not None and use_bias and layer.layer.bias is not None:
model_compression_toolkit/gptq/keras/quantization_facade.py
CHANGED
@@ -19,7 +19,7 @@ from packaging import version
 
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
-    LR_BIAS_DEFAULT, GPTQ_MOMENTUM
+    LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
 from model_compression_toolkit.verify_packages import FOUND_TF
@@ -42,7 +42,7 @@ if FOUND_TF:
     from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
-    from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
+    from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
     from model_compression_toolkit import get_target_platform_capabilities
@@ -61,11 +61,12 @@ if FOUND_TF:
     def get_keras_gptq_config(n_epochs: int,
                               optimizer: OptimizerV2 = None,
                               optimizer_rest: OptimizerV2 = None,
-                              loss: Callable =
+                              loss: Callable = None,
                               log_function: Callable = None,
                               use_hessian_based_weights: bool = True,
-                              regularization_factor: float =
+                              regularization_factor: float = None,
                               hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                              use_hessian_sample_attention: bool = False,
                               gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False) -> GradientPTQConfig:
         """
         Create a GradientPTQConfig instance for Keras models.
@@ -79,6 +80,7 @@ if FOUND_TF:
             use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             regularization_factor (float): A floating point number that defines the regularization factor.
             hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+            use_hessian_sample_attention (bool): whether to use Sample-Layer Attention score for weighted loss.
             gradual_activation_quantization (bool, GradualActivationQuantizationConfig): If False, GradualActivationQuantization is disabled. If True, GradualActivationQuantization is enabled with the default settings. GradualActivationQuantizationConfig object can be passed to use non-default settings.
 
         returns:
@@ -105,9 +107,25 @@ if FOUND_TF:
         """
         optimizer = optimizer or tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT)
         optimizer_rest = optimizer_rest or tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT)
+        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
 
-
-
+        if regularization_factor is None:
+            regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
+
+        loss = loss or GPTQMultipleTensorsLoss()
+        hessian_weights_config = None
+        if use_hessian_sample_attention:
+            if not use_hessian_based_weights:  # pragma: no cover
+                raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
+
+            hessian_weights_config = GPTQHessianScoresConfig(per_sample=True,
+                                                             hessians_num_samples=None,
+                                                             hessian_batch_size=hessian_batch_size)
+            loss = loss or sample_layer_attention_loss
+        elif use_hessian_based_weights:
+            hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
+                                                             hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
+                                                             hessian_batch_size=hessian_batch_size)
 
         if isinstance(gradual_activation_quantization, bool):
             gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
@@ -117,11 +135,6 @@ if FOUND_TF:
             raise TypeError(f'gradual_activation_quantization argument should be bool or '
                             f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
 
-        hessian_weights_config = None
-        if use_hessian_based_weights:
-            hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
-                                                             hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
-                                                             hessian_batch_size=hessian_batch_size)
         return GradientPTQConfig(n_epochs=n_epochs,
                                  optimizer=optimizer,
                                  optimizer_rest=optimizer_rest,
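A minimal, hedged usage sketch of the extended API above (assuming only that model_compression_toolkit is importable as mct): only the config construction is shown; actually running GPTQ additionally requires a float Keras model and a representative data generator.

    import model_compression_toolkit as mct

    # New in this release: opt in to Sample-Layer Attention weighting.
    gptq_cfg = mct.gptq.get_keras_gptq_config(n_epochs=5,
                                              use_hessian_sample_attention=True)
    # With the flag set and regularization_factor left as None, the code shown in the
    # diff above resolves regularization_factor to REG_DEFAULT_SLA and builds a
    # GPTQHessianScoresConfig(per_sample=True, ...).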
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
CHANGED
@@ -40,30 +40,42 @@ class SoftQuantizerRegularization:
         self.count_iter = tf.Variable(0.)
 
 
-    def __call__(self, model: Model, entropy_reg: float):
+    def __call__(self, model: Model, entropy_reg: float, layer_weights: tf.Tensor):
         """
         Returns the soft quantizer regularization value for SoftRounding.
 
         Args:
             model: A model to be quantized with SoftRounding.
             entropy_reg: Entropy value to scale the quantizer regularization.
+            layer_weights: a vector of layers weights.
 
         Returns: Regularization value.
         """
-
+        layers = [l for l in model.layers if isinstance(l, KerasTrainableQuantizationWrapper)]
+
+        if layer_weights.shape[0] != len(layers):
+            raise ValueError(f'Expected weights.shape[0] to be {len(layers)}, '
+                             f'received shape {layer_weights.shape}.')  # pragma: no cover
+
         b = self.beta_scheduler(self.count_iter.value())
-        for layer in model.layers:
-            if isinstance(layer, KerasTrainableQuantizationWrapper):
-                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                      fw_info=DEFAULT_KERAS_INFO)
 
-
-
+        max_w = tf.reduce_max(layer_weights)
+
+        # Initialize reg to zero
+        reg = tf.constant(0.0, dtype=tf.float32)
+
+        # Compute the regularization term without concatenating
+        for i, layer in enumerate(layers):
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)
+
+            st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
 
-
+            soft_loss = tf.reduce_sum(1 - tf.pow(tf.math.abs(st - 0.5) * 2, b))
+            reg += layer_weights[i] * soft_loss
 
-
-
+        # Normalize reg by max_w
+        reg = reg / max_w
 
         self.count_iter.assign_add(1.0)
 
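In words, the reworked __call__ above computes reg = (1 / max_i w_i) * sum_i w_i * sum_j (1 - |2*s_ij - 1|^beta), where s_ij are the soft-rounding targets of layer i and w_i its weight. The standalone sketch below is illustrative only (the helper name and the dummy shapes are not part of the package); it reproduces that arithmetic without the MCT wrapper and scheduler machinery.

    import tensorflow as tf

    def weighted_soft_rounding_reg(soft_targets_per_layer, layer_weights, beta):
        """soft_targets_per_layer: list of tensors with values in [0, 1];
        layer_weights: 1-D tensor with one weight per layer; beta: current annealing value."""
        reg = tf.constant(0.0, dtype=tf.float32)
        for i, st in enumerate(soft_targets_per_layer):
            # Per-layer soft-rounding penalty, weighted by the layer's weight.
            soft_loss = tf.reduce_sum(1 - tf.pow(tf.math.abs(st - 0.5) * 2, beta))
            reg += layer_weights[i] * soft_loss
        # Normalize by the largest layer weight, as in the diff above.
        return reg / tf.reduce_max(layer_weights)

    st_list = [tf.random.uniform((3, 3, 8, 16)), tf.random.uniform((32, 10))]
    w = tf.constant([0.7, 0.2])
    print(float(weighted_soft_rounding_reg(st_list, w, beta=4.0)))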
model_compression_toolkit/gptq/pytorch/gptq_training.py
CHANGED
@@ -21,9 +21,6 @@ from torch.nn import Module
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from model_compression_toolkit.gptq.common.gradual_activation_quantization import get_gradual_activation_quantizer_wrapper_factory
-from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
-
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -41,7 +38,6 @@ from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable
 from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
 
 from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
-from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
 from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import PytorchLinearAnnealingScheduler
 from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import SoftQuantizerRegularization as PytorchSoftQuantizerRegularization
 
@@ -76,13 +72,10 @@ class PytorchGPTQTrainer(GPTQTrainer):
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """
-
-
-
-
-        # must be set prior to model building in the base class constructor
-        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-            gptq_config, _get_total_grad_steps, PytorchLinearAnnealingScheduler)
+        self.fw_soft_quantizer_regularization = PytorchSoftQuantizerRegularization
+        self.fw_linear_annealing_scheduler = PytorchLinearAnnealingScheduler
+        self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+        self.fw_get_weights_for_loss_fn = get_weights_for_loss
 
         super().__init__(graph_float,
                          graph_quant,
@@ -92,40 +85,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)
 
-        self.loss_list = []
-        self.input_scale = 1
-        if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-            Logger.critical("Input scale mismatch between float and GPTQ networks. "
-                            "Ensure both networks have matching input scales.")  # pragma: no cover
-        else:
-            self.input_scale = self.gptq_user_info.input_scale
-
-        trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
-            self.fxp_model,
-            add_bias=self.gptq_config.train_bias)
-
-        self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-        if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                self.fxp_weights_list)):
-            Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, "
-                            "and float vs. quantized weights for loss calculation do not match. "
-                            "Verify consistency across these parameters for successful GPTQ training.")
-
-        self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
-                                                                  trainable_bias,
-                                                                  trainable_threshold)
-        hessian_cfg = self.gptq_config.hessian_weights_config
-
-        self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
-        if self.use_sample_layer_attention:
-            # normalization is currently not supported, make sure the config reflects it.
-            if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
-                raise NotImplementedError()
-            self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
-        else:
-            self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)
-
-        self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps, PytorchSoftQuantizerRegularization, PytorchLinearAnnealingScheduler)
 
     def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
         """
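Both trainers now hand the setup they used to duplicate to the common base class (gptq/common/gptq_training.py, +58 lines in this release) through fw_* attributes assigned before super().__init__. The toy classes below are a hedged sketch of that hook pattern only; they are not the package's actual base-class code.

    class CommonGPTQTrainerSketch:
        def __init__(self, model):
            # The shared constructor only talks to hooks the subclass set beforehand.
            self.trainable_params = self.fw_get_gptq_trainable_parameters_fn(model)
            self.regularizer_cls = self.fw_soft_quantizer_regularization

    class KerasLikeTrainerSketch(CommonGPTQTrainerSketch):
        def __init__(self, model):
            # Hooks must exist before the base constructor runs, as in the diffs above.
            self.fw_get_gptq_trainable_parameters_fn = lambda m: list(m)
            self.fw_soft_quantizer_regularization = object  # stands in for SoftQuantizerRegularization
            super().__init__(model)

    trainer = KerasLikeTrainerSketch(model=[1.0, 2.0, 3.0])
    print(trainer.trainable_params)  # [1.0, 2.0, 3.0]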
{mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/LICENSE.md
RENAMED
File without changes
{mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/top_level.txt
RENAMED
File without changes