mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py
CHANGED
@@ -56,7 +56,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         super(ReduceLROnPlateau, self).__init__()
 
         if factor >= 1.0:
-            Logger.critical('Factor should be < 1.0.')
+            Logger.critical('Factor should be < 1.0.')  # pragma: no cover
         self.factor = factor
 
         self.optimizer = optimizer
@@ -101,7 +101,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         else:
             self.num_bad_epochs += 1
 
-        if self.in_cooldown:
+        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # Ignore any bad epochs in cooldown
 
@@ -122,7 +122,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
            new_lr = max(old_lr * self.factor, self.min_lr)
            if old_lr - new_lr > self.eps:
                tf.keras.backend.set_value(self.optimizer.learning_rate, new_lr)
-                if self.verbose:
+                if self.verbose:
                    print(f'Epoch {epoch:05d}: reducing learning rate to {new_lr:.4e}.')
 
    @property
@@ -152,13 +152,13 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon
-        elif self.mode == 'min' and self.threshold_mode == 'abs':
+        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold
-        elif self.mode == 'max' and self.threshold_mode == 'rel':
+        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon
        else:  # mode == 'max' and threshold_mode == 'abs':
-            return a > best + self.threshold
+            return a > best + self.threshold
 
    def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None:
        """
@@ -186,7 +186,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
        self.threshold = threshold
        self.threshold_mode = threshold_mode
 
-    def get_config(self) -> Dict:
+    def get_config(self) -> Dict:
        """
        Return the configuration of the scheduler as a dictionary.
 
@@ -207,7 +207,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
        base_config = super(ReduceLROnPlateau, self).get_config()
        return {**base_config, **config}
 
-    def set_config(self, config: Dict) -> None:
+    def set_config(self, config: Dict) -> None:
        """
        Set the configuration of the scheduler from a dictionary.
 
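For context, the four `mode`/`threshold_mode` branches touched by the @@ -152 hunk implement a standard plateau test. A minimal sketch of that comparison in plain Python (illustrative only, not the MCT class itself):

def is_better(a: float, best: float, mode: str, threshold: float, threshold_mode: str) -> bool:
    # Mirrors the branch structure shown in the hunk above.
    if mode == 'min' and threshold_mode == 'rel':
        rel_epsilon = 1. - threshold
        return a < best * rel_epsilon      # relative improvement under the best value
    elif mode == 'min' and threshold_mode == 'abs':
        return a < best - threshold        # absolute improvement under the best value
    elif mode == 'max' and threshold_mode == 'rel':
        rel_epsilon = threshold + 1.
        return a > best * rel_epsilon      # relative improvement over the best value
    else:  # mode == 'max' and threshold_mode == 'abs'
        return a > best + threshold        # absolute improvement over the best value

assert is_better(0.89, 1.0, 'min', 0.1, 'rel')      # 0.89 < 1.0 * 0.9
assert not is_better(0.95, 1.0, 'min', 0.1, 'rel')  # 0.95 misses the 10% margin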
model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py
CHANGED
@@ -60,10 +60,10 @@ class ReduceLROnPlateauWithReset:
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            Logger.critical('{} is not an Optimizer'.format(
-                type(optimizer).__name__))
+                type(optimizer).__name__))  # pragma: no cover
        self.optimizer = optimizer
 
-        if isinstance(min_lr, (list, tuple)):
+        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                Logger.critical("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))  # pragma: no cover
@@ -117,7 +117,7 @@ class ReduceLROnPlateauWithReset:
            self.num_bad_epochs += 1
 
        # Handle cooldown period
-        if self.in_cooldown:
+        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # Ignore any bad epochs in cooldown
 
@@ -142,7 +142,7 @@ class ReduceLROnPlateauWithReset:
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
-                if self.verbose:
+                if self.verbose:
                    epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                    print('Epoch {}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
@@ -168,19 +168,19 @@ class ReduceLROnPlateauWithReset:
        Returns:
            bool: True if the new value is better, False otherwise.
        """
-        if best is None:
+        if best is None:
            return True
 
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon
-        elif self.mode == 'min' and self.threshold_mode == 'abs':
+        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold
-        elif self.mode == 'max' and self.threshold_mode == 'rel':
+        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon
        else:  # mode == 'max' and threshold_mode == 'abs':
-            return a > best + self.threshold
+            return a > best + self.threshold
 
    def _init_is_better(self) -> None:
        """
@@ -197,9 +197,9 @@ class ReduceLROnPlateauWithReset:
        if self.mode == 'min':
            self.mode_worse = float('inf')
        else:  # mode == 'max':
-            self.mode_worse = float('-inf')
+            self.mode_worse = float('-inf')
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> Dict[str, Any]:
        """
        Return the state of the scheduler as a dictionary.
 
@@ -208,7 +208,7 @@ class ReduceLROnPlateauWithReset:
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the scheduler state.
 
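The `state_dict`/`load_state_dict` pair shown above serializes every attribute except the attached optimizer. A minimal sketch of that round-trip pattern (the `load_state_dict` body via `__dict__.update` is an assumption following the usual torch-scheduler convention; the MCT source may differ):

from typing import Any, Dict

class SchedulerLike:
    # Illustrative stand-in for ReduceLROnPlateauWithReset's persistence logic.
    def __init__(self, optimizer: Any, factor: float = 0.1):
        self.optimizer = optimizer  # runtime reference, never serialized
        self.factor = factor
        self.num_bad_epochs = 0

    def state_dict(self) -> Dict[str, Any]:
        # Same expression as in the hunk above: drop the optimizer reference.
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.__dict__.update(state_dict)

s1 = SchedulerLike(optimizer=object(), factor=0.5)
s2 = SchedulerLike(optimizer=object())
s2.load_state_dict(s1.state_dict())
assert s2.factor == 0.5 and 'optimizer' not in s1.state_dict()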
model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py
CHANGED
@@ -21,7 +21,6 @@ from model_compression_toolkit.logger import Logger
 
 if FOUND_TF:
    import keras
-    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
    from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
    from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import \
        FakelyQuantKerasExporter
@@ -37,7 +36,6 @@ if FOUND_TF:
        KerasExportSerializationFormat.TFLITE: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.INT8]
    }
 
-    @set_keras_info
    def keras_export_model(model: keras.models.Model,
                           save_model_path: str,
                           is_layer_exportable_fn: Callable = is_keras_layer_exportable,
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py
CHANGED
@@ -19,6 +19,7 @@ import torch.nn
 
 from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
 
+from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.verify_packages import FOUND_ONNX
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py
CHANGED
@@ -27,7 +27,6 @@ DEFAULT_ONNX_OPSET_VERSION = 15
 
 if FOUND_TORCH:
    import torch.nn
-    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
    from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
    from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
    from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
@@ -42,14 +41,13 @@ if FOUND_TORCH:
        PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.MCTQ]
    }
 
-    @set_pytorch_info
    def pytorch_export_model(model: torch.nn.Module,
                             save_model_path: str,
                             repr_dataset: Callable,
                             is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
                             serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX,
                             quantization_format: QuantizationFormat = QuantizationFormat.MCTQ,
-                             onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION,
+                             onnx_opset_version: int = DEFAULT_ONNX_OPSET_VERSION,
                             output_names: Optional[List[str]] = None) -> None:
        """
        Export a PyTorch quantized model to a torchscript or onnx model.
@@ -60,16 +58,14 @@ if FOUND_TORCH:
        (where the model will be saved to ONNX model).
 
        Args:
-            model: Model to export.
-            save_model_path: Path to save the model.
-            repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
-            is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
-            serialization_format: Format to export the model according to (by default
-
-
-
-            output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
-                This argument is relevant only when using PytorchExportSerializationFormat.ONNX.
+            model (Module): Model to export.
+            save_model_path (str): Path to save the model.
+            repr_dataset (Callable): Representative dataset for tracing the pytorch model (mandatory for exporting it).
+            is_layer_exportable_fn (Callable): Callable to check whether a layer can be exported or not.
+            serialization_format (PytorchExportSerializationFormat): Format to export the model according to (by default PytorchExportSerializationFormat.ONNX).
+            quantization_format (QuantizationFormat): Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
+            onnx_opset_version (int): ONNX opset version to use for exported ONNX model.
+            output_names (Optional[List[str]]): Optional list of output node names for export compatibility. This argument is relevant only when using PytorchExportSerializationFormat.ONNX.
 
        """
        # Ensure 'metadata' is available directly on the model, if present in submodules
model_compression_toolkit/gptq/common/gptq_graph.py
CHANGED
@@ -14,8 +14,8 @@
 # ==============================================================================
 from typing import Tuple, List
 
+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 
@@ -40,7 +40,8 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
    compare_points_name = []
    for n in input_graph.get_topo_sorted_nodes():
        # only nodes with kernel attribute are currently trained with GPTQ and are used as compare points
-
+        kernel_attr = input_graph.fw_info.get_kernel_op_attributes(n.type)[0]
+        if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr) and not n.reuse:
            compare_points.append(n)
            compare_points_name.append(n.name)
            compare_points_std.append(n.prior_info.std_output)
@@ -48,15 +49,20 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
    return compare_points, compare_points_name, compare_points_mean, compare_points_std
 
 
-def get_kernel_attribute_name_for_gptq(layer_type: type) -> str:
+def get_kernel_attribute_name_for_gptq(layer_type: type, fw_info: FrameworkInfo) -> str:
    """
    Returns a layer's kernel attribute name for GPTQ training purposes.
 
    Args:
        layer_type: A type of model's layer.
+        fw_info: A FrameworkInfo object.
 
    Returns: The name of the kernel attribute.
 
    """
-
-
+    kernel_attribute = fw_info.get_kernel_op_attributes(layer_type)
+    if len(kernel_attribute) != 1:
+        Logger.critical(  # pragma: no cover
+            f"In GPTQ training, only the kernel weights attribute should be trained. "
+            f"However, the number of kernel attributes is {len(kernel_attribute)}.")
+    return kernel_attribute[0]
model_compression_toolkit/gptq/common/gptq_training.py
CHANGED
@@ -44,6 +44,7 @@ class GPTQTrainer(ABC):
                 graph_quant: Graph,
                 gptq_config: GradientPTQConfig,
                 fw_impl: GPTQFrameworkImplemantation,
+                 fw_info: FrameworkInfo,
                 representative_data_gen_fn: Callable[[], Generator],
                 hessian_info_service: HessianInfoService = None):
        """
@@ -57,6 +58,7 @@ class GPTQTrainer(ABC):
            graph_quant: Graph to build a quantized networks from.
            gptq_config: GradientPTQConfig with parameters about the tuning process.
            fw_impl: Framework implementation
+            fw_info: Framework information
            representative_data_gen_fn: factory for representative data generator.
            hessian_info_service: HessianInfoService for fetching and computing Hessian-approximation information.
        """
@@ -64,6 +66,7 @@ class GPTQTrainer(ABC):
        self.graph_quant = copy.deepcopy(graph_quant)
        self.gptq_config = gptq_config
        self.fw_impl = fw_impl
+        self.fw_info = fw_info
        self.representative_data_gen_fn = representative_data_gen_fn
 
        def _get_total_grad_steps():
@@ -80,7 +83,8 @@ class GPTQTrainer(ABC):
 
        self.float_model, self.float_user_info = fw_impl.model_builder(self.graph_float,
                                                                       mode=ModelBuilderMode.FLOAT,
-                                                                       append2output=self.compare_points
+                                                                       append2output=self.compare_points,
+                                                                       fw_info=self.fw_info)
 
        self.fxp_model, self.gptq_user_info = self.build_gptq_model()
        if self.gptq_config.hessian_weights_config:
@@ -284,6 +288,7 @@ def gptq_training(graph_float: Graph,
                  gptq_config: GradientPTQConfig,
                  representative_data_gen: Callable,
                  fw_impl: GPTQFrameworkImplemantation,
+                  fw_info: FrameworkInfo,
                  hessian_info_service: HessianInfoService = None) -> Graph:
    """
    GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).
@@ -293,6 +298,7 @@ def gptq_training(graph_float: Graph,
        gptq_config: GradientPTQConfig with parameters about the tuning process.
        representative_data_gen: Dataset to use for inputs of the models.
        fw_impl: Framework implementation
+        fw_info: Framework information
        hessian_info_service: HessianInfoService to fetch information based on the Hessian approximation.
 
    Returns:
@@ -306,6 +312,7 @@ def gptq_training(graph_float: Graph,
                              graph_quant,
                              gptq_config,
                              fw_impl,
+                              fw_info,
                              representative_data_gen,
                              hessian_info_service=hessian_info_service)
 
model_compression_toolkit/gptq/keras/gptq_training.py
CHANGED
@@ -65,6 +65,7 @@ class KerasGPTQTrainer(GPTQTrainer):
                 graph_quant: Graph,
                 gptq_config: GradientPTQConfig,
                 fw_impl: FrameworkImplementation,
+                 fw_info: FrameworkInfo,
                 representative_data_gen: Callable,
                 hessian_info_service: HessianInfoService = None):
        """
@@ -78,6 +79,7 @@ class KerasGPTQTrainer(GPTQTrainer):
            graph_quant: Graph to build a quantized networks from.
            gptq_config: GradientPTQConfig with parameters about the tuning process.
            fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+            fw_info: Framework information.
            representative_data_gen: Dataset to use for inputs of the models.
            hessian_info_service: HessianScoresService for fetching and computing Hessian's approximation scores.
 
@@ -92,6 +94,7 @@ class KerasGPTQTrainer(GPTQTrainer):
                         graph_quant,
                         gptq_config,
                         fw_impl,
+                         fw_info,
                         representative_data_gen_fn=representative_data_gen,
                         hessian_info_service=hessian_info_service)
 
@@ -207,7 +210,8 @@ class KerasGPTQTrainer(GPTQTrainer):
        Returns:
            A boolean whether the layer is to be wrapped with a QuantizeWrapper
        """
-
+        kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
+        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)
 
    def gptq_wrapper(self,
                     n: common.BaseNode,
@@ -226,7 +230,7 @@ class KerasGPTQTrainer(GPTQTrainer):
        # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
        weights_quantizers, _ = quantization_builder(n,
                                                     self.gptq_config,  # TODO: split quantizers building into two functions: for weights and activations
-                                                     n.
+                                                     self.fw_info.get_kernel_op_attributes(n.type)[0])
        if len(weights_quantizers) > 0:
            return KerasTrainableQuantizationWrapper(layer,
                                                     weights_quantizers=weights_quantizers)
@@ -267,6 +271,7 @@ class KerasGPTQTrainer(GPTQTrainer):
 
        gptq_model, gptq_user_info = KerasModelBuilder(graph=self.graph_quant,
                                                       append2output=self.compare_points,
+                                                       fw_info=self.fw_info,
                                                       return_float_outputs=True,
                                                       wrapper=self.gptq_wrapper,
                                                       get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
@@ -426,7 +431,8 @@ class KerasGPTQTrainer(GPTQTrainer):
            Logger.critical(f"Unable to update the GPTQ graph because the layer named '{layer.layer.name}' could not be found. "
                            f"Verify that the layer names in the GPTQ model match those in the graph.")
        node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                              fw_info=self.fw_info)
        # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
        # To enable GPTQ for other attributes, this code needs to be modified.
        weights, weight_quant_config, activation_quant_config = \
model_compression_toolkit/gptq/keras/graph_info.py
CHANGED
@@ -16,8 +16,8 @@
 import tensorflow as tf
 from typing import Tuple, List
 from model_compression_toolkit.core.keras.constants import USE_BIAS
-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from tensorflow.keras.models import Model
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
@@ -44,7 +44,8 @@ def get_gptq_trainable_parameters(fxp_model: Model,
 
    for layer in fxp_model.layers:
        if isinstance(layer, KerasTrainableQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)
 
            # collect trainable weights per quantizer
            if kernel_attribute not in layer.weights_quantizers:
@@ -56,8 +57,9 @@ def get_gptq_trainable_parameters(fxp_model: Model,
            trainable_threshold.extend(quantizer_trainable_threshold)
 
            if add_bias:
-
-                use_bias =
+                kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
+                use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
+                           and layer.layer.get_config().get(USE_BIAS)
                if use_bias is not None and use_bias and layer.layer.bias is not None:
                    bias_weights.append([layer.layer.bias])
 
model_compression_toolkit/gptq/keras/quantization_facade.py
CHANGED
@@ -41,8 +41,9 @@ from model_compression_toolkit.metadata import create_model_metadata
 
 if FOUND_TF:
    import tensorflow as tf
-    from model_compression_toolkit.core.keras.default_framework_info import
+    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
    from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
+    from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
    from tensorflow.keras.models import Model
    from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
@@ -151,7 +152,6 @@ if FOUND_TF:
            gradual_activation_quantization_config=gradual_quant_config)
 
 
-    @set_keras_info
    def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
                                                  gptq_config: GradientPTQConfig,
                                                  gptq_representative_data_gen: Callable = None,
@@ -234,13 +234,16 @@ if FOUND_TF:
        if core_config.debug_config.bypass:
            return in_model, None
 
+        KerasModelValidation(model=in_model,
+                             fw_info=DEFAULT_KERAS_INFO).validate()
+
        if core_config.is_mixed_precision_enabled:
            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                                "Ensure usage of the correct API for keras_post_training_quantization "
                                "or provide a valid mixed-precision configuration.")  # pragma: no cover
 
-        tb_w = init_tensorboard_writer()
+        tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
 
        fw_impl = GPTQKerasImplemantation()
 
@@ -254,6 +257,7 @@ if FOUND_TF:
        tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
                                                                                   representative_data_gen=representative_data_gen,
                                                                                   core_config=core_config,
+                                                                                   fw_info=DEFAULT_KERAS_INFO,
                                                                                   fw_impl=fw_impl,
                                                                                   fqc=framework_platform_capabilities,
                                                                                   target_resource_utilization=target_resource_utilization,
@@ -267,6 +271,7 @@ if FOUND_TF:
                                          gptq_config,
                                          representative_data_gen,
                                          gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
+                                          DEFAULT_KERAS_INFO,
                                          fw_impl,
                                          tb_w,
                                          hessian_info_service=hessian_info_service)
@@ -278,7 +283,8 @@ if FOUND_TF:
                                   tb_w,
                                   float_graph,
                                   tg_gptq,
-                                   fw_impl
+                                   fw_impl,
+                                   DEFAULT_KERAS_INFO)
 
        exportable_model, user_info = get_exportable_keras_model(tg_gptq)
        if framework_platform_capabilities.tpc.add_metadata:
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
CHANGED
@@ -17,6 +17,7 @@ from typing import List, Callable
 import tensorflow as tf
 from keras import Model
 
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 
@@ -65,7 +66,8 @@ class SoftQuantizerRegularization:
 
        # Compute the regularization term without concatenating
        for i, layer in enumerate(layers):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)
 
            st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
 
model_compression_toolkit/gptq/pytorch/gptq_training.py
CHANGED
@@ -54,6 +54,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
                 graph_quant: Graph,
                 gptq_config: GradientPTQConfig,
                 fw_impl: FrameworkImplementation,
+                 fw_info: FrameworkInfo,
                 representative_data_gen: Callable,
                 hessian_info_service: HessianInfoService = None):
        """
@@ -67,6 +68,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
            graph_quant: Graph to build a quantized networks from.
            gptq_config: GradientPTQConfigV2 with parameters about the tuning process.
            fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+            fw_info: Framework information
            representative_data_gen: Dataset to use for inputs of the models.
            hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
        """
@@ -79,6 +81,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
                         graph_quant,
                         gptq_config,
                         fw_impl,
+                         fw_info,
                         representative_data_gen_fn=representative_data_gen,
                         hessian_info_service=hessian_info_service)
 
@@ -164,7 +167,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
            A boolean whether the layer is to be wrapped with a Quantization Wrapper.
        """
 
-
+        kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
+        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)
 
    def gptq_wrapper(self,
                     n: BaseNode,
@@ -183,7 +187,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
        # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
        weights_quantizers, _ = quantization_builder(n,
                                                     self.gptq_config,
-                                                     n.
+                                                     self.fw_info.get_kernel_op_attributes(n.type)[0])
 
        if len(weights_quantizers) > 0:
            return PytorchQuantizationWrapper(layer,
@@ -220,6 +224,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
        """
        gptq_model, gptq_user_info = PyTorchModelBuilder(graph=self.graph_quant,
                                                         append2output=self.compare_points,
+                                                         fw_info=self.fw_info,
                                                         wrapper=self.gptq_wrapper,
                                                         return_float_outputs=True,
                                                         get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
@@ -335,7 +340,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
            Logger.critical(f"Cannot update GPTQ graph: Layer with name '{name}' is missing or not unique. "
                            f"Ensure each layer has a unique name and exists within the graph for updates.")
        node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                              fw_info=self.fw_info)
        # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
        # To enable GPTQ for other attributes, this code needs to be modified.
        weights, weight_quant_config, activation_quant_config = \
model_compression_toolkit/gptq/pytorch/graph_info.py
CHANGED
@@ -16,6 +16,7 @@ import torch
 import torch.nn as nn
 from typing import List
 from model_compression_toolkit.core.pytorch.constants import BIAS
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from mct_quantizers import PytorchQuantizationWrapper
@@ -42,7 +43,8 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
 
    for layer in fxp_model.modules():
        if isinstance(layer, PytorchQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_PYTORCH_INFO)
 
            # collect trainable weights per quantizer
            if kernel_attribute not in layer.weights_quantizers:
model_compression_toolkit/gptq/pytorch/quantization_facade.py
CHANGED
@@ -39,7 +39,7 @@ from model_compression_toolkit.verify_packages import FOUND_TORCH
 
 
 if FOUND_TORCH:
-    from model_compression_toolkit.core.pytorch.default_framework_info import
+    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
    from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
    from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss, sample_layer_attention_loss
@@ -142,8 +142,6 @@ if FOUND_TORCH:
            gradual_activation_quantization_config=gradual_quant_config,
            log_function=log_function)
 
-
-    @set_pytorch_info
    def pytorch_gradient_post_training_quantization(model: Module,
                                                    representative_data_gen: Callable,
                                                    target_resource_utilization: ResourceUtilization = None,
@@ -218,7 +216,8 @@ if FOUND_TORCH:
                Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                                "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
                                "or provide a valid mixed-precision configuration.")
-
+
+        tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
 
        fw_impl = GPTQPytorchImplemantation()
 
@@ -234,6 +233,7 @@ if FOUND_TORCH:
        graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
                                                                                      representative_data_gen=representative_data_gen,
                                                                                      core_config=core_config,
+                                                                                      fw_info=DEFAULT_PYTORCH_INFO,
                                                                                      fw_impl=fw_impl,
                                                                                      fqc=framework_quantization_capabilities,
                                                                                      target_resource_utilization=target_resource_utilization,
@@ -250,6 +250,7 @@ if FOUND_TORCH:
                                            gptq_config,
                                            representative_data_gen,
                                            gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
+                                            DEFAULT_PYTORCH_INFO,
                                            fw_impl,
                                            tb_w,
                                            hessian_info_service=hessian_info_service)
@@ -259,7 +260,8 @@ if FOUND_TORCH:
                                   tb_w,
                                   float_graph,
                                   graph_gptq,
-                                   fw_impl
+                                   fw_impl,
+                                   DEFAULT_PYTORCH_INFO)
 
        exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
        if framework_quantization_capabilities.tpc.add_metadata:
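As with the keras facade, `pytorch_gradient_post_training_quantization` keeps its public signature and only gains the internal `DEFAULT_PYTORCH_INFO` plumbing. A hedged usage sketch mirroring the keras example (model and dataset are placeholders; `mct.gptq` aliases assumed from MCT's public API):

import torch
import model_compression_toolkit as mct

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())  # placeholder model

def representative_data_gen():
    for _ in range(2):
        yield [torch.randn(1, 3, 32, 32)]

gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
    model, representative_data_gen, gptq_config=gptq_config)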
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
CHANGED
@@ -18,6 +18,7 @@ import torch
 from torch import nn
 
 from mct_quantizers import PytorchQuantizationWrapper
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 
 
@@ -60,7 +61,8 @@ class SoftQuantizerRegularization:
        b = self.beta_scheduler(self.count_iter)
        reg = 0
        for layer, w in zip(layers, layer_weights):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_PYTORCH_INFO)
 
            st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
            soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()