mct-nightly 2.4.0.20250616.616__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py

@@ -60,10 +60,10 @@ class ReduceLROnPlateauWithReset:
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
             Logger.critical('{} is not an Optimizer'.format(
-                type(optimizer).__name__))
+                type(optimizer).__name__))  # pragma: no cover
         self.optimizer = optimizer

-        if isinstance(min_lr, (list, tuple)):
+        if isinstance(min_lr, (list, tuple)):  # pragma: no cover
             if len(min_lr) != len(optimizer.param_groups):
                 Logger.critical("expected {} min_lrs, got {}".format(
                     len(optimizer.param_groups), len(min_lr)))  # pragma: no cover

@@ -117,7 +117,7 @@ class ReduceLROnPlateauWithReset:
             self.num_bad_epochs += 1

         # Handle cooldown period
-        if self.in_cooldown:
+        if self.in_cooldown:  # pragma: no cover
             self.cooldown_counter -= 1
             self.num_bad_epochs = 0  # Ignore any bad epochs in cooldown

@@ -142,7 +142,7 @@ class ReduceLROnPlateauWithReset:
             new_lr = max(old_lr * self.factor, self.min_lrs[i])
             if old_lr - new_lr > self.eps:
                 param_group['lr'] = new_lr
-                if self.verbose:
+                if self.verbose:  # pragma: no cover
                     epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                     print('Epoch {}: reducing learning rate'
                           ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))

@@ -168,19 +168,19 @@ class ReduceLROnPlateauWithReset:
         Returns:
             bool: True if the new value is better, False otherwise.
         """
-        if best is None:
+        if best is None:  # pragma: no cover
             return True

         if self.mode == 'min' and self.threshold_mode == 'rel':
             rel_epsilon = 1. - self.threshold
             return a < best * rel_epsilon
-        elif self.mode == 'min' and self.threshold_mode == 'abs':
+        elif self.mode == 'min' and self.threshold_mode == 'abs':  # pragma: no cover
             return a < best - self.threshold
-        elif self.mode == 'max' and self.threshold_mode == 'rel':
+        elif self.mode == 'max' and self.threshold_mode == 'rel':  # pragma: no cover
             rel_epsilon = self.threshold + 1.
             return a > best * rel_epsilon
         else:  # mode == 'max' and threshold_mode == 'abs':
-            return a > best + self.threshold
+            return a > best + self.threshold  # pragma: no cover

     def _init_is_better(self) -> None:
         """

@@ -197,9 +197,9 @@ class ReduceLROnPlateauWithReset:
         if self.mode == 'min':
             self.mode_worse = float('inf')
         else:  # mode == 'max':
-            self.mode_worse = float('-inf')
+            self.mode_worse = float('-inf')  # pragma: no cover

-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> Dict[str, Any]:  # pragma: no cover
         """
         Return the state of the scheduler as a dictionary.

@@ -208,7 +208,7 @@ class ReduceLROnPlateauWithReset:
         """
         return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:  # pragma: no cover
         """
         Load the scheduler state.
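Every change in this file adds a `# pragma: no cover` marker, the standard coverage.py exclusion comment: the marked statement, and the branch body it guards, are omitted from coverage measurement. A minimal, self-contained illustration of the effect (this snippet is not code from the package):

    # Minimal illustration of coverage.py's "# pragma: no cover" marker.
    def safe_scale(value: float, factor: float) -> float:
        if not isinstance(factor, (int, float)):  # pragma: no cover
            # Defensive branch excluded from coverage reports, so an
            # untested error path does not lower the measured coverage.
            raise TypeError(f"{type(factor).__name__} is not a number")
        return value * factor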
model_compression_toolkit/gptq/common/gptq_graph.py

@@ -14,8 +14,8 @@
 # ==============================================================================
 from typing import Tuple, List

-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode

@@ -40,8 +40,7 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
     compare_points_name = []
     for n in input_graph.get_topo_sorted_nodes():
         # only nodes with kernel attribute are currently trained with GPTQ and are used as compare points
-        kernel_attr
-        if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr) and not n.reuse:
+        if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and not n.reuse:
             compare_points.append(n)
             compare_points_name.append(n.name)
             compare_points_std.append(n.prior_info.std_output)

@@ -49,20 +48,15 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
     return compare_points, compare_points_name, compare_points_mean, compare_points_std


-def get_kernel_attribute_name_for_gptq(layer_type: type
+def get_kernel_attribute_name_for_gptq(layer_type: type) -> str:
     """
     Returns a layer's kernel attribute name for GPTQ training purposes.

     Args:
         layer_type: A type of model's layer.
-        fw_info: A FrameworkInfo object.

     Returns: The name of the kernel attribute.

     """
-
-
-    Logger.critical(  # pragma: no cover
-        f"In GPTQ training, only the kernel weights attribute should be trained. "
-        f"However, the number of kernel attributes is {len(kernel_attribute)}.")
-    return kernel_attribute[0]
+
+    return get_fw_info().get_kernel_op_attribute(layer_type)
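The theme of this hunk, and of most of the GPTQ and pruning changes below, is that the explicit `FrameworkInfo`/`fw_info` argument threaded through call chains is replaced by a process-global lookup: `get_fw_info()` reads an object that the `set_keras_info`/`set_pytorch_info` decorators (seen in the facade diffs further down) install before an API entry point runs, while per-node kernel information moves to `n.kernel_attr`. A minimal sketch of such a registry, with the caveat that only `get_fw_info`, `set_keras_info` and `set_pytorch_info` appear in this diff and every other name below is illustrative:

    # Hypothetical sketch of a process-global framework-info registry.
    import functools

    class _DemoFrameworkInfo:
        """Illustrative stand-in for the real per-framework info object."""
        def get_kernel_op_attribute(self, layer_type: type) -> str:
            return 'kernel'  # e.g. the kernel weight attribute of a conv layer

    _FW_INFO = None

    def set_fw_info(fw_info) -> None:
        global _FW_INFO
        _FW_INFO = fw_info

    def get_fw_info():
        if _FW_INFO is None:
            raise RuntimeError('Framework info has not been set.')
        return _FW_INFO

    def set_keras_info(func):
        # Decorator: install the framework info before the wrapped entry point runs.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            set_fw_info(_DemoFrameworkInfo())
            return func(*args, **kwargs)
        return wrapper

Under a pattern like this, `get_kernel_attribute_name_for_gptq(layer_type)` above no longer needs a `fw_info` parameter, which is exactly what the reduced signatures throughout the rest of the diff reflect.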
model_compression_toolkit/gptq/common/gptq_training.py

@@ -44,7 +44,6 @@ class GPTQTrainer(ABC):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: GPTQFrameworkImplemantation,
-                 fw_info: FrameworkInfo,
                  representative_data_gen_fn: Callable[[], Generator],
                  hessian_info_service: HessianInfoService = None):
         """

@@ -58,7 +57,6 @@ class GPTQTrainer(ABC):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfig with parameters about the tuning process.
             fw_impl: Framework implementation
-            fw_info: Framework information
             representative_data_gen_fn: factory for representative data generator.
             hessian_info_service: HessianInfoService for fetching and computing Hessian-approximation information.
         """

@@ -66,7 +64,6 @@ class GPTQTrainer(ABC):
         self.graph_quant = copy.deepcopy(graph_quant)
         self.gptq_config = gptq_config
         self.fw_impl = fw_impl
-        self.fw_info = fw_info
         self.representative_data_gen_fn = representative_data_gen_fn

         def _get_total_grad_steps():

@@ -83,8 +80,7 @@ class GPTQTrainer(ABC):

         self.float_model, self.float_user_info = fw_impl.model_builder(self.graph_float,
                                                                        mode=ModelBuilderMode.FLOAT,
-                                                                       append2output=self.compare_points,
-                                                                       fw_info=self.fw_info)
+                                                                       append2output=self.compare_points)

         self.fxp_model, self.gptq_user_info = self.build_gptq_model()
         if self.gptq_config.hessian_weights_config:

@@ -288,7 +284,6 @@ def gptq_training(graph_float: Graph,
                   gptq_config: GradientPTQConfig,
                   representative_data_gen: Callable,
                   fw_impl: GPTQFrameworkImplemantation,
-                  fw_info: FrameworkInfo,
                   hessian_info_service: HessianInfoService = None) -> Graph:
     """
     GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).

@@ -298,7 +293,6 @@ def gptq_training(graph_float: Graph,
         gptq_config: GradientPTQConfig with parameters about the tuning process.
         representative_data_gen: Dataset to use for inputs of the models.
         fw_impl: Framework implementation
-        fw_info: Framework information
         hessian_info_service: HessianInfoService to fetch information based on the Hessian approximation.

     Returns:

@@ -312,7 +306,6 @@ def gptq_training(graph_float: Graph,
                               graph_quant,
                               gptq_config,
                               fw_impl,
-                              fw_info,
                               representative_data_gen,
                               hessian_info_service=hessian_info_service)

model_compression_toolkit/gptq/keras/gptq_training.py

@@ -65,7 +65,6 @@ class KerasGPTQTrainer(GPTQTrainer):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
-                 fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
                  hessian_info_service: HessianInfoService = None):
         """

@@ -79,7 +78,6 @@ class KerasGPTQTrainer(GPTQTrainer):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfig with parameters about the tuning process.
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-            fw_info: Framework information.
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianScoresService for fetching and computing Hessian's approximation scores.

@@ -94,7 +92,6 @@ class KerasGPTQTrainer(GPTQTrainer):
                          graph_quant,
                          gptq_config,
                          fw_impl,
-                         fw_info,
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)

@@ -210,8 +207,7 @@ class KerasGPTQTrainer(GPTQTrainer):
         Returns:
             A boolean whether the layer is to be wrapped with a QuantizeWrapper
         """
-        kernel_attr
-        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)
+        return node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)

     def gptq_wrapper(self,
                      n: common.BaseNode,

@@ -230,7 +226,7 @@ class KerasGPTQTrainer(GPTQTrainer):
         # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
         weights_quantizers, _ = quantization_builder(n,
                                                      self.gptq_config,  # TODO: split quantizers building into two functions: for weights and activations
-
+                                                     n.kernel_attr)
         if len(weights_quantizers) > 0:
             return KerasTrainableQuantizationWrapper(layer,
                                                      weights_quantizers=weights_quantizers)

@@ -271,7 +267,6 @@ class KerasGPTQTrainer(GPTQTrainer):

         gptq_model, gptq_user_info = KerasModelBuilder(graph=self.graph_quant,
                                                        append2output=self.compare_points,
-                                                       fw_info=self.fw_info,
                                                        return_float_outputs=True,
                                                        wrapper=self.gptq_wrapper,
                                                        get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()

@@ -431,8 +426,7 @@ class KerasGPTQTrainer(GPTQTrainer):
         Logger.critical(f"Unable to update the GPTQ graph because the layer named '{layer.layer.name}' could not be found. "
                         f"Verify that the layer names in the GPTQ model match those in the graph.")
         node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
-                                                              fw_info=self.fw_info)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type)
         # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
         # To enable GPTQ for other attributes, this code needs to be modified.
         weights, weight_quant_config, activation_quant_config = \
model_compression_toolkit/gptq/keras/graph_info.py

@@ -16,8 +16,8 @@
 import tensorflow as tf
 from typing import Tuple, List
 from model_compression_toolkit.core.keras.constants import USE_BIAS
+from model_compression_toolkit.core.common.framework_info import get_fw_info
 from tensorflow.keras.models import Model
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper

@@ -44,8 +44,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,

     for layer in fxp_model.layers:
         if isinstance(layer, KerasTrainableQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_KERAS_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))

             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:

@@ -57,9 +56,8 @@ def get_gptq_trainable_parameters(fxp_model: Model,
             trainable_threshold.extend(quantizer_trainable_threshold)

             if add_bias:
-
-                use_bias =
-                    and layer.layer.get_config().get(USE_BIAS)
+                kernel_ops_attr = get_fw_info().get_kernel_op_attribute(type(layer.layer))
+                use_bias = kernel_ops_attr is not None and layer.layer.get_config().get(USE_BIAS)
                 if use_bias is not None and use_bias and layer.layer.bias is not None:
                     bias_weights.append([layer.layer.bias])

model_compression_toolkit/gptq/keras/quantization_facade.py

@@ -41,7 +41,7 @@ from model_compression_toolkit.metadata import create_model_metadata

 if FOUND_TF:
     import tensorflow as tf
-    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model

@@ -152,6 +152,7 @@ if FOUND_TF:
                                           gradual_activation_quantization_config=gradual_quant_config)


+    @set_keras_info
     def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
                                                   gptq_config: GradientPTQConfig,
                                                   gptq_representative_data_gen: Callable = None,

@@ -234,8 +235,7 @@ if FOUND_TF:
         if core_config.debug_config.bypass:
             return in_model, None

-        KerasModelValidation(model=in_model,
-                             fw_info=DEFAULT_KERAS_INFO).validate()
+        KerasModelValidation(model=in_model).validate()

         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):

@@ -243,7 +243,7 @@ if FOUND_TF:
                 "Ensure usage of the correct API for keras_post_training_quantization "
                 "or provide a valid mixed-precision configuration.")  # pragma: no cover

-        tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
+        tb_w = init_tensorboard_writer()

         fw_impl = GPTQKerasImplemantation()

@@ -257,7 +257,6 @@ if FOUND_TF:
         tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
                                                                                    representative_data_gen=representative_data_gen,
                                                                                    core_config=core_config,
-                                                                                   fw_info=DEFAULT_KERAS_INFO,
                                                                                    fw_impl=fw_impl,
                                                                                    fqc=framework_platform_capabilities,
                                                                                    target_resource_utilization=target_resource_utilization,

@@ -271,7 +270,6 @@ if FOUND_TF:
                              gptq_config,
                              representative_data_gen,
                              gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
-                             DEFAULT_KERAS_INFO,
                              fw_impl,
                              tb_w,
                              hessian_info_service=hessian_info_service)

@@ -283,8 +281,7 @@ if FOUND_TF:
                  tb_w,
                  float_graph,
                  tg_gptq,
-                 fw_impl,
-                 DEFAULT_KERAS_INFO)
+                 fw_impl)

         exportable_model, user_info = get_exportable_keras_model(tg_gptq)
         if framework_platform_capabilities.tpc.add_metadata:
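With `@set_keras_info` applied, callers invoke this facade without any framework-info plumbing. A hedged usage sketch, where the keyword names follow the signature fragments visible above but `float_model`, the generator shapes, and the config arguments are assumptions:

    import numpy as np
    import model_compression_toolkit as mct

    def representative_data_gen():
        # Yield batches shaped like the model's input; random data for illustration.
        for _ in range(10):
            yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

    gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)
    quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
        in_model=float_model,  # assumed: a tf.keras Model defined elsewhere
        representative_data_gen=representative_data_gen,
        gptq_config=gptq_config)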
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py

@@ -17,7 +17,6 @@ from typing import List, Callable
 import tensorflow as tf
 from keras import Model

-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper

@@ -66,8 +65,7 @@ class SoftQuantizerRegularization:

         # Compute the regularization term without concatenating
         for i, layer in enumerate(layers):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_KERAS_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))

             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()

model_compression_toolkit/gptq/pytorch/gptq_training.py

@@ -54,7 +54,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
-                 fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
                  hessian_info_service: HessianInfoService = None):
         """

@@ -68,7 +67,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfigV2 with parameters about the tuning process.
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-            fw_info: Framework information
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """

@@ -81,7 +79,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                          graph_quant,
                          gptq_config,
                          fw_impl,
-                         fw_info,
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)

@@ -167,8 +164,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
            A boolean whether the layer is to be wrapped with a Quantization Wrapper.
         """

-        kernel_attr
-        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)
+        return node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)

     def gptq_wrapper(self,
                      n: BaseNode,

@@ -187,7 +183,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
         # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
         weights_quantizers, _ = quantization_builder(n,
                                                      self.gptq_config,
-
+                                                     n.kernel_attr)

         if len(weights_quantizers) > 0:
             return PytorchQuantizationWrapper(layer,

@@ -224,7 +220,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
         """
         gptq_model, gptq_user_info = PyTorchModelBuilder(graph=self.graph_quant,
                                                          append2output=self.compare_points,
-                                                         fw_info=self.fw_info,
                                                          wrapper=self.gptq_wrapper,
                                                          return_float_outputs=True,
                                                          get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()

@@ -340,8 +335,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
         Logger.critical(f"Cannot update GPTQ graph: Layer with name '{name}' is missing or not unique. "
                         f"Ensure each layer has a unique name and exists within the graph for updates.")
         node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
-                                                              fw_info=self.fw_info)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type)
         # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
         # To enable GPTQ for other attributes, this code needs to be modified.
         weights, weight_quant_config, activation_quant_config = \
model_compression_toolkit/gptq/pytorch/graph_info.py

@@ -16,7 +16,6 @@ import torch
 import torch.nn as nn
 from typing import List
 from model_compression_toolkit.core.pytorch.constants import BIAS
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from mct_quantizers import PytorchQuantizationWrapper

@@ -43,8 +42,7 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,

     for layer in fxp_model.modules():
         if isinstance(layer, PytorchQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_PYTORCH_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))

             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:
model_compression_toolkit/gptq/pytorch/quantization_facade.py

@@ -39,7 +39,7 @@ from model_compression_toolkit.verify_packages import FOUND_TORCH


 if FOUND_TORCH:
-    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss, sample_layer_attention_loss

@@ -142,6 +142,8 @@ if FOUND_TORCH:
                                          gradual_activation_quantization_config=gradual_quant_config,
                                          log_function=log_function)

+
+    @set_pytorch_info
     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
                                                     target_resource_utilization: ResourceUtilization = None,

@@ -216,8 +218,7 @@ if FOUND_TORCH:
             Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                             "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
                             "or provide a valid mixed-precision configuration.")
-
-        tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
+        tb_w = init_tensorboard_writer()

         fw_impl = GPTQPytorchImplemantation()

@@ -233,7 +234,6 @@ if FOUND_TORCH:
         graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
                                                                                       representative_data_gen=representative_data_gen,
                                                                                       core_config=core_config,
-                                                                                      fw_info=DEFAULT_PYTORCH_INFO,
                                                                                       fw_impl=fw_impl,
                                                                                       fqc=framework_quantization_capabilities,
                                                                                       target_resource_utilization=target_resource_utilization,

@@ -250,7 +250,6 @@ if FOUND_TORCH:
                              gptq_config,
                              representative_data_gen,
                              gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
-                             DEFAULT_PYTORCH_INFO,
                              fw_impl,
                              tb_w,
                              hessian_info_service=hessian_info_service)

@@ -260,8 +259,7 @@ if FOUND_TORCH:
                  tb_w,
                  float_graph,
                  graph_gptq,
-                 fw_impl,
-                 DEFAULT_PYTORCH_INFO)
+                 fw_impl)

         exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
         if framework_quantization_capabilities.tpc.add_metadata:
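The PyTorch facade mirrors the Keras one: `@set_pytorch_info` installs the framework info, so the public call shape stays unchanged. A hedged usage sketch under the same assumptions as the Keras example (argument names follow the fragments above; `float_model` and the config call are assumptions):

    import torch
    import model_compression_toolkit as mct

    def representative_data_gen():
        # Yield batches shaped like the model's input; random data for illustration.
        for _ in range(10):
            yield [torch.randn(1, 3, 224, 224)]

    gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
    quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
        model=float_model,  # assumed: a torch.nn.Module defined elsewhere
        representative_data_gen=representative_data_gen,
        gptq_config=gptq_config)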
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py

@@ -18,7 +18,6 @@ import torch
 from torch import nn

 from mct_quantizers import PytorchQuantizationWrapper
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq


@@ -61,8 +60,7 @@ class SoftQuantizerRegularization:
         b = self.beta_scheduler(self.count_iter)
         reg = 0
         for layer, w in zip(layers, layer_weights):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_PYTORCH_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))

             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
             soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
model_compression_toolkit/gptq/runner.py

@@ -37,7 +37,6 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
                 tb_w: TensorboardWriter,
                 tg: Graph,
                 tg_bias: Graph,
-                fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 hessian_info_service: HessianInfoService = None) -> Graph:
     """

@@ -52,7 +51,6 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
         tb_w: TensorBoardWriter object to log events.
         tg: Float Reference Graph.
         tg_bias: Graph of quantized model.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
         fw_impl: Framework implementation per framework
         hessian_info_service: HessianInfoService to fetch information based on the hessian approximation for the float model.
     Returns:

@@ -64,7 +62,6 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
                               gptq_config,
                               representative_data_gen,
                               fw_impl,
-                              fw_info,
                               hessian_info_service=hessian_info_service)

     if tb_w is not None:

@@ -77,7 +74,6 @@ def gptq_runner(tg: Graph,
                 gptq_config: GradientPTQConfig,
                 representative_data_gen: Callable,
                 gptq_representative_data_gen: Callable,
-                fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 tb_w: TensorboardWriter,
                 hessian_info_service: HessianInfoService = None) -> Graph:

@@ -91,7 +87,6 @@ def gptq_runner(tg: Graph,
         gptq_config: GradientPTQConfig with parameters about the tuning process.
         representative_data_gen: Dataset used for calibration.
         gptq_representative_data_gen: Dataset used for GPTQ training
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
         hessian_info_service: HessianScoresService to fetch approximations of the hessian scores for the float model.

@@ -104,7 +99,7 @@ def gptq_runner(tg: Graph,
     #############################################
     # Apply Statistics Correction
     #############################################
-    tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config,
+    tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)

     if tb_w is not None:
         tb_w.add_graph(tg_bias, 'after_bias_correction')

@@ -117,7 +112,6 @@ def gptq_runner(tg: Graph,
                           tb_w,
                           tg,
                           tg_bias,
-                          fw_info,
                           fw_impl,
                           hessian_info_service=hessian_info_service)

model_compression_toolkit/pruning/keras/pruning_facade.py

@@ -35,11 +35,12 @@ if FOUND_TF:
         AttachTpcToKeras
     from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
     from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
-    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from tensorflow.keras.models import Model

     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

+    @set_keras_info
     def keras_pruning_experimental(model: Model,
                                    target_resource_utilization: ResourceUtilization,
                                    representative_data_gen: Callable,

@@ -123,7 +124,6 @@ if FOUND_TF:
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
                                           target_platform_capabilities,
-                                          DEFAULT_KERAS_INFO,
                                           fw_impl)

         # Apply quantization configuration to the graph. This step is necessary even when not quantizing,

@@ -134,7 +134,6 @@ if FOUND_TF:

         # Create a Pruner object with the graph and configuration.
         pruner = Pruner(float_graph_with_compression_config,
-                        DEFAULT_KERAS_INFO,
                         fw_impl,
                         target_resource_utilization,
                         representative_data_gen,
model_compression_toolkit/pruning/pytorch/pruning_facade.py

@@ -36,7 +36,7 @@ if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
     from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
         PruningPytorchImplementation
-    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from torch.nn import Module
     from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
         AttachTpcToPytorch

@@ -44,6 +44,7 @@ if FOUND_TORCH:
     # Set the default Target Platform Capabilities (TPC) for PyTorch.
     DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

+    @set_pytorch_info
     def pytorch_pruning_experimental(model: Module,
                                      target_resource_utilization: ResourceUtilization,
                                      representative_data_gen: Callable,

@@ -129,7 +130,6 @@ if FOUND_TORCH:
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
                                           framework_platform_capabilities,
-                                          DEFAULT_PYTORCH_INFO,
                                           fw_impl)

         # Apply quantization configuration to the graph. This step is necessary even when not quantizing,

@@ -140,7 +140,6 @@ if FOUND_TORCH:

         # Create a Pruner object with the graph and configuration.
         pruner = Pruner(float_graph_with_compression_config,
-                        DEFAULT_PYTORCH_INFO,
                         fw_impl,
                         target_resource_utilization,
                         representative_data_gen,
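Both pruning facades now self-install their framework info via the decorators, so callers keep the same public call shape. A hedged sketch of the PyTorch variant (the `ResourceUtilization` keyword and the memory figure are assumptions for illustration):

    import torch
    import model_compression_toolkit as mct

    def representative_data_gen():
        for _ in range(10):
            yield [torch.randn(1, 3, 224, 224)]

    # Ask the pruner to fit weight memory into ~2 MB (illustrative figure).
    target_ru = mct.core.ResourceUtilization(weights_memory=2 * 1024 * 1024)

    pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(
        model=float_model,  # assumed: a torch.nn.Module defined elsewhere
        target_resource_utilization=target_ru,
        representative_data_gen=representative_data_gen)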