mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240306.post426__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/METADATA +5 -5
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/RECORD +42 -40
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
- model_compression_toolkit/core/common/quantization/core_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
- model_compression_toolkit/core/pytorch/constants.py +3 -0
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
- model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
- model_compression_toolkit/gptq/__init__.py +1 -1
- model_compression_toolkit/gptq/common/gptq_config.py +5 -72
- model_compression_toolkit/gptq/keras/gptq_training.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +19 -33
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +3 -3
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +2 -4
- model_compression_toolkit/gptq/pytorch/gptq_training.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +14 -31
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +2 -4
- model_compression_toolkit/gptq/runner.py +3 -3
- model_compression_toolkit/pruning/__init__.py +1 -0
- model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
- model_compression_toolkit/ptq/__init__.py +2 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -30
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -30
- model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/top_level.txt +0 -0
|
@@ -21,10 +21,10 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
|
|
|
21
21
|
from model_compression_toolkit.logger import Logger
|
|
22
22
|
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
23
23
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
24
|
-
from model_compression_toolkit.gptq.common.gptq_config import
|
|
24
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
25
25
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
26
26
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
27
|
-
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import
|
|
27
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
|
|
28
28
|
from model_compression_toolkit.core import CoreConfig
|
|
29
29
|
from model_compression_toolkit.core.runner import core_runner
|
|
30
30
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
|
@@ -66,7 +66,7 @@ if FOUND_TF:
|
|
|
66
66
|
loss: Callable = GPTQMultipleTensorsLoss(),
|
|
67
67
|
log_function: Callable = None,
|
|
68
68
|
use_hessian_based_weights: bool = True,
|
|
69
|
-
regularization_factor: float = REG_DEFAULT) ->
|
|
69
|
+
regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
|
|
70
70
|
"""
|
|
71
71
|
Create a GradientPTQConfigV2 instance for Keras models.
|
|
72
72
|
|
|
@@ -102,26 +102,25 @@ if FOUND_TF:
|
|
|
102
102
|
"""
|
|
103
103
|
bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
|
|
104
104
|
momentum=GPTQ_MOMENTUM)
|
|
105
|
-
return
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
105
|
+
return GradientPTQConfig(n_epochs,
|
|
106
|
+
optimizer,
|
|
107
|
+
optimizer_rest=optimizer_rest,
|
|
108
|
+
loss=loss,
|
|
109
|
+
log_function=log_function,
|
|
110
|
+
train_bias=True,
|
|
111
|
+
optimizer_bias=bias_optimizer,
|
|
112
|
+
use_hessian_based_weights=use_hessian_based_weights,
|
|
113
|
+
regularization_factor=regularization_factor)
|
|
114
114
|
|
|
115
115
|
|
|
116
116
|
def keras_gradient_post_training_quantization(in_model: Model,
|
|
117
117
|
representative_data_gen: Callable,
|
|
118
|
-
gptq_config:
|
|
118
|
+
gptq_config: GradientPTQConfig,
|
|
119
119
|
gptq_representative_data_gen: Callable = None,
|
|
120
120
|
target_kpi: KPI = None,
|
|
121
121
|
core_config: CoreConfig = CoreConfig(),
|
|
122
122
|
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
|
|
123
|
-
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC,
|
|
124
|
-
new_experimental_exporter: bool = True) -> Tuple[Model, UserInformation]:
|
|
123
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
|
|
125
124
|
"""
|
|
126
125
|
Quantize a trained Keras model using post-training quantization. The model is quantized using a
|
|
127
126
|
symmetric constraint quantization thresholds (power of two).
|
|
@@ -141,13 +140,12 @@ if FOUND_TF:
|
|
|
141
140
|
Args:
|
|
142
141
|
in_model (Model): Keras model to quantize.
|
|
143
142
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
144
|
-
gptq_config (
|
|
143
|
+
gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
|
|
145
144
|
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
|
|
146
145
|
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
|
|
147
146
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
|
|
148
147
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
|
|
149
148
|
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
150
|
-
new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
|
|
151
149
|
|
|
152
150
|
Returns:
|
|
153
151
|
|
|
@@ -177,7 +175,7 @@ if FOUND_TF:
|
|
|
177
175
|
with different bitwidths for different layers.
|
|
178
176
|
The candidates bitwidth for quantization should be defined in the target platform model:
|
|
179
177
|
|
|
180
|
-
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.
|
|
178
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
|
|
181
179
|
|
|
182
180
|
For mixed-precision set a target KPI object:
|
|
183
181
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
@@ -199,9 +197,9 @@ if FOUND_TF:
|
|
|
199
197
|
fw_info=fw_info).validate()
|
|
200
198
|
|
|
201
199
|
if core_config.mixed_precision_enable:
|
|
202
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
200
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
203
201
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
204
|
-
"
|
|
202
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
|
|
205
203
|
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
206
204
|
|
|
207
205
|
tb_w = init_tensorboard_writer(fw_info)
|
|
@@ -232,19 +230,7 @@ if FOUND_TF:
|
|
|
232
230
|
if core_config.debug_config.analyze_similarity:
|
|
233
231
|
analyzer_model_quantization(representative_data_gen, tb_w, tg_gptq, fw_impl, fw_info)
|
|
234
232
|
|
|
235
|
-
|
|
236
|
-
Logger.warning('Using new experimental wrapped and ready for export models. To '
|
|
237
|
-
'disable it, please set new_experimental_exporter to False when '
|
|
238
|
-
'calling keras_gradient_post_training_quantization. '
|
|
239
|
-
'If you encounter an issue please file a bug.')
|
|
240
|
-
|
|
241
|
-
return get_exportable_keras_model(tg_gptq)
|
|
242
|
-
|
|
243
|
-
return export_model(tg_gptq,
|
|
244
|
-
fw_info,
|
|
245
|
-
fw_impl,
|
|
246
|
-
tb_w,
|
|
247
|
-
bit_widths_config)
|
|
233
|
+
return get_exportable_keras_model(tg_gptq)
|
|
248
234
|
|
|
249
235
|
else:
|
|
250
236
|
# If tensorflow is not installed,
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Dict, List, Tuple
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.gptq import
|
|
17
|
+
from model_compression_toolkit.gptq import GradientPTQConfig
|
|
18
18
|
from model_compression_toolkit.core import common
|
|
19
19
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
20
20
|
from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
|
|
@@ -33,7 +33,7 @@ from model_compression_toolkit.trainable_infrastructure.common.get_quantizers im
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
def quantization_builder(n: common.BaseNode,
|
|
36
|
-
gptq_config:
|
|
36
|
+
gptq_config: GradientPTQConfig
|
|
37
37
|
) -> Tuple[Dict[str, BaseKerasGPTQTrainableQuantizer], List[BaseKerasInferableQuantizer]]:
|
|
38
38
|
"""
|
|
39
39
|
Build quantizers for a node according to its quantization configuration and
|
|
@@ -41,7 +41,7 @@ def quantization_builder(n: common.BaseNode,
|
|
|
41
41
|
|
|
42
42
|
Args:
|
|
43
43
|
n: Node to build its QuantizeConfig.
|
|
44
|
-
gptq_config (
|
|
44
|
+
gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration.
|
|
45
45
|
|
|
46
46
|
Returns:
|
|
47
47
|
A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Callable
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.gptq import RoundingType,
|
|
17
|
+
from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
|
|
18
18
|
from model_compression_toolkit.gptq.keras.quantizer.soft_rounding.soft_quantizer_reg import \
|
|
19
19
|
SoftQuantizerRegularization
|
|
20
20
|
|
|
@@ -38,8 +38,6 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
|
|
|
38
38
|
for _ in representative_data_gen():
|
|
39
39
|
num_batches += 1
|
|
40
40
|
|
|
41
|
-
|
|
42
|
-
not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
|
|
43
|
-
return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
|
|
41
|
+
return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
|
|
44
42
|
else:
|
|
45
43
|
return lambda m, e_reg: 0
|
|
@@ -25,7 +25,7 @@ from model_compression_toolkit.logger import Logger
|
|
|
25
25
|
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
|
|
26
26
|
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
27
27
|
from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
|
|
28
|
-
from model_compression_toolkit.gptq.common.gptq_config import
|
|
28
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
29
29
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
30
30
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
31
31
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
@@ -46,7 +46,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
46
46
|
def __init__(self,
|
|
47
47
|
graph_float: Graph,
|
|
48
48
|
graph_quant: Graph,
|
|
49
|
-
gptq_config:
|
|
49
|
+
gptq_config: GradientPTQConfig,
|
|
50
50
|
fw_impl: FrameworkImplementation,
|
|
51
51
|
fw_info: FrameworkInfo,
|
|
52
52
|
representative_data_gen: Callable,
|
|
@@ -19,7 +19,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
|
|
|
19
19
|
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
|
|
20
20
|
from model_compression_toolkit.logger import Logger
|
|
21
21
|
from model_compression_toolkit.constants import PYTORCH
|
|
22
|
-
from model_compression_toolkit.gptq.common.gptq_config import
|
|
22
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
23
23
|
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
25
25
|
from model_compression_toolkit.core.runner import core_runner
|
|
@@ -29,7 +29,7 @@ from model_compression_toolkit.core.exporter import export_model
|
|
|
29
29
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
30
30
|
from model_compression_toolkit.core import CoreConfig
|
|
31
31
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
32
|
-
|
|
32
|
+
MixedPrecisionQuantizationConfig
|
|
33
33
|
|
|
34
34
|
LR_DEFAULT = 1e-4
|
|
35
35
|
LR_REST_DEFAULT = 1e-4
|
|
@@ -54,7 +54,7 @@ if FOUND_TORCH:
|
|
|
54
54
|
loss: Callable = multiple_tensors_mse_loss,
|
|
55
55
|
log_function: Callable = None,
|
|
56
56
|
use_hessian_based_weights: bool = True,
|
|
57
|
-
regularization_factor: float = REG_DEFAULT) ->
|
|
57
|
+
regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
|
|
58
58
|
"""
|
|
59
59
|
Create a GradientPTQConfigV2 instance for Pytorch models.
|
|
60
60
|
|
|
@@ -86,21 +86,19 @@ if FOUND_TORCH:
|
|
|
86
86
|
|
|
87
87
|
"""
|
|
88
88
|
bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
|
|
89
|
-
return
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
89
|
+
return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
|
|
90
|
+
log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
|
|
91
|
+
use_hessian_based_weights=use_hessian_based_weights,
|
|
92
|
+
regularization_factor=regularization_factor)
|
|
93
93
|
|
|
94
94
|
|
|
95
95
|
def pytorch_gradient_post_training_quantization(model: Module,
|
|
96
96
|
representative_data_gen: Callable,
|
|
97
97
|
target_kpi: KPI = None,
|
|
98
98
|
core_config: CoreConfig = CoreConfig(),
|
|
99
|
-
gptq_config:
|
|
99
|
+
gptq_config: GradientPTQConfig = None,
|
|
100
100
|
gptq_representative_data_gen: Callable = None,
|
|
101
|
-
target_platform_capabilities: TargetPlatformCapabilities =
|
|
102
|
-
DEFAULT_PYTORCH_TPC,
|
|
103
|
-
new_experimental_exporter: bool = True):
|
|
101
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
|
|
104
102
|
"""
|
|
105
103
|
Quantize a trained Pytorch module using post-training quantization.
|
|
106
104
|
By default, the module is quantized using a symmetric constraint quantization thresholds
|
|
@@ -122,10 +120,9 @@ if FOUND_TORCH:
|
|
|
122
120
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
123
121
|
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
|
|
124
122
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
|
|
125
|
-
gptq_config (
|
|
123
|
+
gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
|
|
126
124
|
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
|
|
127
125
|
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
128
|
-
new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
|
|
129
126
|
|
|
130
127
|
Returns:
|
|
131
128
|
A quantized module and information the user may need to handle the quantized module.
|
|
@@ -157,9 +154,9 @@ if FOUND_TORCH:
|
|
|
157
154
|
"""
|
|
158
155
|
|
|
159
156
|
if core_config.mixed_precision_enable:
|
|
160
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
157
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
161
158
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
162
|
-
"
|
|
159
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
|
|
163
160
|
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
164
161
|
|
|
165
162
|
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
|
|
@@ -194,22 +191,8 @@ if FOUND_TORCH:
|
|
|
194
191
|
if core_config.debug_config.analyze_similarity:
|
|
195
192
|
analyzer_model_quantization(representative_data_gen, tb_w, graph_gptq, fw_impl, DEFAULT_PYTORCH_INFO)
|
|
196
193
|
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
# ---------------------- #
|
|
200
|
-
if new_experimental_exporter:
|
|
201
|
-
Logger.warning('Using new experimental wrapped and ready for export models. To '
|
|
202
|
-
'disable it, please set new_experimental_exporter to False when '
|
|
203
|
-
'calling pytorch_gradient_post_training_quantization_experimental. '
|
|
204
|
-
'If you encounter an issue please file a bug.')
|
|
205
|
-
|
|
206
|
-
return get_exportable_pytorch_model(graph_gptq)
|
|
207
|
-
|
|
208
|
-
return export_model(graph_gptq,
|
|
209
|
-
DEFAULT_PYTORCH_INFO,
|
|
210
|
-
fw_impl,
|
|
211
|
-
tb_w,
|
|
212
|
-
bit_widths_config)
|
|
194
|
+
return get_exportable_pytorch_model(graph_gptq)
|
|
195
|
+
|
|
213
196
|
|
|
214
197
|
else:
|
|
215
198
|
# If torch is not installed,
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import List, Dict, Tuple
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.gptq import
|
|
17
|
+
from model_compression_toolkit.gptq import GradientPTQConfig
|
|
18
18
|
from model_compression_toolkit.core import common
|
|
19
19
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
|
20
20
|
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
|
|
@@ -34,7 +34,7 @@ from model_compression_toolkit.trainable_infrastructure.common.get_quantizers im
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def quantization_builder(n: common.BaseNode,
|
|
37
|
-
gptq_config:
|
|
37
|
+
gptq_config: GradientPTQConfig,
|
|
38
38
|
) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer],
|
|
39
39
|
List[BasePyTorchInferableQuantizer]]:
|
|
40
40
|
"""
|
|
@@ -43,7 +43,7 @@ def quantization_builder(n: common.BaseNode,
|
|
|
43
43
|
|
|
44
44
|
Args:
|
|
45
45
|
n: Node to build its QuantizeConfig.
|
|
46
|
-
gptq_config (
|
|
46
|
+
gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration.
|
|
47
47
|
|
|
48
48
|
Returns:
|
|
49
49
|
A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Callable
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.gptq import RoundingType,
|
|
17
|
+
from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
|
|
18
18
|
from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
|
|
19
19
|
SoftQuantizerRegularization
|
|
20
20
|
|
|
@@ -38,8 +38,6 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
|
|
|
38
38
|
for _ in representative_data_gen():
|
|
39
39
|
num_batches += 1
|
|
40
40
|
|
|
41
|
-
|
|
42
|
-
not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
|
|
43
|
-
return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
|
|
41
|
+
return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
|
|
44
42
|
else:
|
|
45
43
|
return lambda m, e_reg: 0
|
|
@@ -20,7 +20,7 @@ from model_compression_toolkit.core import common
|
|
|
20
20
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
|
21
21
|
from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
|
|
22
22
|
apply_statistics_correction
|
|
23
|
-
from model_compression_toolkit.gptq.common.gptq_config import
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
24
24
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
25
25
|
from model_compression_toolkit.core.common import FrameworkInfo
|
|
26
26
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
@@ -32,7 +32,7 @@ from model_compression_toolkit.core.common.statistics_correction.apply_bias_corr
|
|
|
32
32
|
from model_compression_toolkit.logger import Logger
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
def _apply_gptq(gptq_config:
|
|
35
|
+
def _apply_gptq(gptq_config: GradientPTQConfig,
|
|
36
36
|
representative_data_gen: Callable,
|
|
37
37
|
tb_w: TensorboardWriter,
|
|
38
38
|
tg: Graph,
|
|
@@ -74,7 +74,7 @@ def _apply_gptq(gptq_config: GradientPTQConfigV2,
|
|
|
74
74
|
|
|
75
75
|
def gptq_runner(tg: Graph,
|
|
76
76
|
core_config: CoreConfig,
|
|
77
|
-
gptq_config:
|
|
77
|
+
gptq_config: GradientPTQConfig,
|
|
78
78
|
representative_data_gen: Callable,
|
|
79
79
|
gptq_representative_data_gen: Callable,
|
|
80
80
|
fw_info: FrameworkInfo,
|
|
@@ -16,4 +16,5 @@
|
|
|
16
16
|
from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
|
|
17
17
|
from model_compression_toolkit.core.common.pruning.pruning_config import ImportanceMetric, PruningConfig, ChannelsFilteringStrategy
|
|
18
18
|
from model_compression_toolkit.pruning.keras.pruning_facade import keras_pruning_experimental
|
|
19
|
+
from model_compression_toolkit.pruning.pytorch.pruning_facade import pytorch_pruning_experimental
|
|
19
20
|
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Callable, Tuple
|
|
17
|
+
from model_compression_toolkit import get_target_platform_capabilities
|
|
18
|
+
from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
|
|
19
|
+
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
20
|
+
from model_compression_toolkit.core.common.pruning.pruner import Pruner
|
|
21
|
+
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
|
|
22
|
+
from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
|
|
23
|
+
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
|
|
24
|
+
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
|
|
25
|
+
from model_compression_toolkit.logger import Logger
|
|
26
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
27
|
+
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
|
28
|
+
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Check if PyTorch is available in the environment.
|
|
32
|
+
if FOUND_TORCH:
|
|
33
|
+
# Import PyTorch-specific modules from the model compression toolkit.
|
|
34
|
+
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
|
|
35
|
+
from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
|
|
36
|
+
PruningPytorchImplementation
|
|
37
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
38
|
+
from torch.nn import Module
|
|
39
|
+
|
|
40
|
+
# Set the default Target Platform Capabilities (TPC) for PyTorch.
|
|
41
|
+
DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
|
42
|
+
|
|
43
|
+
def pytorch_pruning_experimental(model: Module,
|
|
44
|
+
target_kpi: KPI,
|
|
45
|
+
representative_data_gen: Callable,
|
|
46
|
+
pruning_config: PruningConfig = PruningConfig(),
|
|
47
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
|
|
48
|
+
Tuple[Module, PruningInfo]:
|
|
49
|
+
"""
|
|
50
|
+
Perform structured pruning on a Pytorch model to meet a specified target KPI.
|
|
51
|
+
This function prunes the provided model according to the target KPI by grouping and pruning
|
|
52
|
+
channels based on each layer's SIMD configuration in the Target Platform Capabilities (TPC).
|
|
53
|
+
By default, the importance of each channel group is determined using the Label-Free Hessian
|
|
54
|
+
(LFH) method, assessing each channel's sensitivity to the Hessian of the loss function.
|
|
55
|
+
This pruning strategy considers groups of channels together for a more hardware-friendly
|
|
56
|
+
architecture. The process involves analyzing the model with a representative dataset to
|
|
57
|
+
identify groups of channels that can be removed with minimal impact on performance.
|
|
58
|
+
|
|
59
|
+
Notice that the pruned model must be retrained to recover the compressed model's performance.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
model (Module): The PyTorch model to be pruned.
|
|
63
|
+
target_kpi (KPI): Key Performance Indicators specifying the pruning targets.
|
|
64
|
+
representative_data_gen (Callable): A function to generate representative data for pruning analysis.
|
|
65
|
+
pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
|
|
66
|
+
target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
|
|
67
|
+
Defaults to DEFAULT_PYTORCH_TPC.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
Tuple[Model, PruningInfo]: A tuple containing the pruned Pytorch model and associated pruning information.
|
|
71
|
+
|
|
72
|
+
Note:
|
|
73
|
+
The pruned model should be fine-tuned or retrained to recover or improve its performance post-pruning.
|
|
74
|
+
|
|
75
|
+
Examples:
|
|
76
|
+
|
|
77
|
+
Import MCT:
|
|
78
|
+
|
|
79
|
+
>>> import model_compression_toolkit as mct
|
|
80
|
+
|
|
81
|
+
Import a Pytorch model:
|
|
82
|
+
|
|
83
|
+
>>> from torchvision.models import resnet50, ResNet50_Weights
|
|
84
|
+
>>> model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
|
85
|
+
|
|
86
|
+
Create a random dataset generator:
|
|
87
|
+
|
|
88
|
+
>>> import numpy as np
|
|
89
|
+
>>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]
|
|
90
|
+
|
|
91
|
+
Define a target KPI for pruning.
|
|
92
|
+
Here, we aim to reduce the memory footprint of weights by 50%, assuming the model weights
|
|
93
|
+
are represented in float32 data type (thus, each parameter is represented using 4 bytes):
|
|
94
|
+
|
|
95
|
+
>>> dense_nparams = sum(p.numel() for p in model.state_dict().values())
|
|
96
|
+
>>> target_kpi = mct.KPI(weights_memory=dense_nparams * 4 * 0.5)
|
|
97
|
+
|
|
98
|
+
Optionally, define a pruning configuration. num_score_approximations can be passed
|
|
99
|
+
to configure the number of importance scores that will be calculated for each channel.
|
|
100
|
+
A higher value for this parameter yields more precise score approximations but also
|
|
101
|
+
extends the duration of the pruning process:
|
|
102
|
+
|
|
103
|
+
>>> pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)
|
|
104
|
+
|
|
105
|
+
Perform pruning:
|
|
106
|
+
|
|
107
|
+
>>> pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(model=model, target_kpi=target_kpi, representative_data_gen=repr_datagen, pruning_config=pruning_config)
|
|
108
|
+
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
# Instantiate the Pytorch framework implementation.
|
|
112
|
+
fw_impl = PruningPytorchImplementation()
|
|
113
|
+
|
|
114
|
+
# Convert the original Pytorch model to an internal graph representation.
|
|
115
|
+
float_graph = read_model_to_graph(model,
|
|
116
|
+
representative_data_gen,
|
|
117
|
+
target_platform_capabilities,
|
|
118
|
+
DEFAULT_PYTORCH_INFO,
|
|
119
|
+
fw_impl)
|
|
120
|
+
|
|
121
|
+
# Apply quantization configuration to the graph. This step is necessary even when not quantizing,
|
|
122
|
+
# as it prepares the graph for the pruning process.
|
|
123
|
+
float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
|
|
124
|
+
quant_config=DEFAULTCONFIG,
|
|
125
|
+
mixed_precision_enable=False)
|
|
126
|
+
|
|
127
|
+
# Create a Pruner object with the graph and configuration.
|
|
128
|
+
pruner = Pruner(float_graph_with_compression_config,
|
|
129
|
+
DEFAULT_PYTORCH_INFO,
|
|
130
|
+
fw_impl,
|
|
131
|
+
target_kpi,
|
|
132
|
+
representative_data_gen,
|
|
133
|
+
pruning_config,
|
|
134
|
+
target_platform_capabilities)
|
|
135
|
+
|
|
136
|
+
# Apply the pruning process.
|
|
137
|
+
pruned_graph = pruner.prune_graph()
|
|
138
|
+
|
|
139
|
+
# Retrieve pruning information which includes the pruning masks and scores.
|
|
140
|
+
pruning_info = pruner.get_pruning_info()
|
|
141
|
+
|
|
142
|
+
# Rebuild the pruned graph back into a trainable Pytorch model.
|
|
143
|
+
pruned_model, _ = FloatPyTorchModelBuilder(graph=pruned_graph).build_model()
|
|
144
|
+
pruned_model.trainable = True
|
|
145
|
+
|
|
146
|
+
# Return the pruned model along with its pruning information.
|
|
147
|
+
return pruned_model, pruning_info
|
|
148
|
+
|
|
149
|
+
else:
|
|
150
|
+
def pytorch_pruning_experimental(*args, **kwargs):
|
|
151
|
+
"""
|
|
152
|
+
Raises a critical error if PyTorch is not installed but the pruning function is invoked.
|
|
153
|
+
|
|
154
|
+
This function acts as a placeholder to provide a clear error message when PyTorch dependencies are missing,
|
|
155
|
+
indicating that the pruning functionality cannot be used without the PyTorch framework installed.
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
*args: Variable length argument list, not used.
|
|
159
|
+
**kwargs: Arbitrary keyword arguments, not used.
|
|
160
|
+
|
|
161
|
+
Raises:
|
|
162
|
+
CriticalError: Indicates that PyTorch must be installed to use this function.
|
|
163
|
+
"""
|
|
164
|
+
Logger.critical('Installing Pytorch is mandatory '
|
|
165
|
+
'when using pytorch_pruning_experimental. '
|
|
166
|
+
'Could not find the torch package.') # pragma: no cover
|
|
@@ -13,5 +13,5 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.ptq.pytorch.quantization_facade import
|
|
17
|
-
from model_compression_toolkit.ptq.keras.quantization_facade import
|
|
16
|
+
from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization
|
|
17
|
+
from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization
|