mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240306.post426__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/METADATA +5 -5
  2. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/RECORD +42 -40
  3. model_compression_toolkit/core/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +2 -2
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  7. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
  9. model_compression_toolkit/core/common/quantization/core_config.py +3 -3
  10. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
  11. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
  13. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  14. model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
  15. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
  16. model_compression_toolkit/core/pytorch/constants.py +3 -0
  17. model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
  18. model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
  19. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
  21. model_compression_toolkit/gptq/__init__.py +1 -1
  22. model_compression_toolkit/gptq/common/gptq_config.py +5 -72
  23. model_compression_toolkit/gptq/keras/gptq_training.py +2 -2
  24. model_compression_toolkit/gptq/keras/quantization_facade.py +19 -33
  25. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +3 -3
  26. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +2 -4
  27. model_compression_toolkit/gptq/pytorch/gptq_training.py +2 -2
  28. model_compression_toolkit/gptq/pytorch/quantization_facade.py +14 -31
  29. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +3 -3
  30. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +2 -4
  31. model_compression_toolkit/gptq/runner.py +3 -3
  32. model_compression_toolkit/pruning/__init__.py +1 -0
  33. model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
  34. model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
  35. model_compression_toolkit/ptq/__init__.py +2 -2
  36. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -30
  37. model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -30
  38. model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
  39. model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
  41. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
  42. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/LICENSE.md +0 -0
  43. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/WHEEL +0 -0
  44. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/top_level.txt +0 -0
model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -21,10 +21,10 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.gptq.runner import gptq_runner
@@ -66,7 +66,7 @@ if FOUND_TF:
                               loss: Callable = GPTQMultipleTensorsLoss(),
                               log_function: Callable = None,
                               use_hessian_based_weights: bool = True,
-                              regularization_factor: float = REG_DEFAULT) -> GradientPTQConfigV2:
+                              regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
         """
         Create a GradientPTQConfigV2 instance for Keras models.

@@ -102,26 +102,25 @@ if FOUND_TF:
         """
         bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
                                                  momentum=GPTQ_MOMENTUM)
-        return GradientPTQConfigV2(n_epochs,
-                                   optimizer,
-                                   optimizer_rest=optimizer_rest,
-                                   loss=loss,
-                                   log_function=log_function,
-                                   train_bias=True,
-                                   optimizer_bias=bias_optimizer,
-                                   use_hessian_based_weights=use_hessian_based_weights,
-                                   regularization_factor=regularization_factor)
+        return GradientPTQConfig(n_epochs,
+                                 optimizer,
+                                 optimizer_rest=optimizer_rest,
+                                 loss=loss,
+                                 log_function=log_function,
+                                 train_bias=True,
+                                 optimizer_bias=bias_optimizer,
+                                 use_hessian_based_weights=use_hessian_based_weights,
+                                 regularization_factor=regularization_factor)


     def keras_gradient_post_training_quantization(in_model: Model,
                                                   representative_data_gen: Callable,
-                                                  gptq_config: GradientPTQConfigV2,
+                                                  gptq_config: GradientPTQConfig,
                                                   gptq_representative_data_gen: Callable = None,
                                                   target_kpi: KPI = None,
                                                   core_config: CoreConfig = CoreConfig(),
                                                   fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
-                                                  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC,
-                                                  new_experimental_exporter: bool = True) -> Tuple[Model, UserInformation]:
+                                                  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -141,13 +140,12 @@ if FOUND_TF:
         Args:
             in_model (Model): Keras model to quantize.
             representative_data_gen (Callable): Dataset used for calibration.
-            gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
+            gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
             gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
-            new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.

         Returns:

@@ -177,7 +175,7 @@ if FOUND_TF:
         with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:

-        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
+        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -199,9 +197,9 @@ if FOUND_TF:
                                    fw_info=fw_info).validate()

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.")  # pragma: no cover

         tb_w = init_tensorboard_writer(fw_info)
@@ -232,19 +230,7 @@ if FOUND_TF:
         if core_config.debug_config.analyze_similarity:
             analyzer_model_quantization(representative_data_gen, tb_w, tg_gptq, fw_impl, fw_info)

-        if new_experimental_exporter:
-            Logger.warning('Using new experimental wrapped and ready for export models. To '
-                           'disable it, please set new_experimental_exporter to False when '
-                           'calling keras_gradient_post_training_quantization. '
-                           'If you encounter an issue please file a bug.')
-
-            return get_exportable_keras_model(tg_gptq)
-
-        return export_model(tg_gptq,
-                            fw_info,
-                            fw_impl,
-                            tb_w,
-                            bit_widths_config)
+        return get_exportable_keras_model(tg_gptq)

 else:
     # If tensorflow is not installed,
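With the `new_experimental_exporter` flag gone, the Keras GPTQ facade always returns the quantizer-wrapped, export-ready model. A minimal usage sketch of the updated API (the model choice and random calibration data are illustrative only, and the `mct.gptq` aliases are assumed to match this nightly's exports):

```python
import numpy as np
import model_compression_toolkit as mct
from tensorflow.keras.applications import MobileNetV2

def repr_datagen():
    # Representative calibration data; replace with real samples.
    yield [np.random.random((1, 224, 224, 3))]

# Build a GradientPTQConfig (formerly GradientPTQConfigV2) with default optimizers.
gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)

# There is no new_experimental_exporter argument anymore; the returned model
# is already wrapped with quantization information, ready for export.
quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
    MobileNetV2(), repr_datagen, gptq_config=gptq_config)
```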
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import Dict, List, Tuple

-from model_compression_toolkit.gptq import GradientPTQConfigV2
+from model_compression_toolkit.gptq import GradientPTQConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
@@ -33,7 +33,7 @@ from model_compression_toolkit.trainable_infrastructure.common.get_quantizers im


 def quantization_builder(n: common.BaseNode,
-                         gptq_config: GradientPTQConfigV2
+                         gptq_config: GradientPTQConfig
                          ) -> Tuple[Dict[str, BaseKerasGPTQTrainableQuantizer], List[BaseKerasInferableQuantizer]]:
     """
     Build quantizers for a node according to its quantization configuration and
@@ -41,7 +41,7 @@ def quantization_builder(n: common.BaseNode,

     Args:
         n: Node to build its QuantizeConfig.
-        gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
+        gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration.

     Returns:
         A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import Callable

-from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
 from model_compression_toolkit.gptq.keras.quantizer.soft_rounding.soft_quantizer_reg import \
     SoftQuantizerRegularization

@@ -38,8 +38,6 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
         for _ in representative_data_gen():
             num_batches += 1

-        n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
-            not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
-        return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+        return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
     else:
         return lambda m, e_reg: 0
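The deleted branch existed only to upgrade legacy v1 configs; with a single `GradientPTQConfig` class, the epoch count is read directly. A sketch of the resulting computation (a hypothetical helper, not MCT API, assuming the representative data generator is finite); the PyTorch regularization factory later in this diff receives the identical change:

```python
def total_gradient_steps(gptq_config, representative_data_gen) -> int:
    # One optimization step is taken per batch per epoch, so the length of the
    # soft-rounding regularizer's schedule is batches * epochs.
    num_batches = sum(1 for _ in representative_data_gen())
    return num_batches * gptq_config.n_epochs
```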
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -25,7 +25,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -46,7 +46,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
     def __init__(self,
                  graph_float: Graph,
                  graph_quant: Graph,
-                 gptq_config: GradientPTQConfigV2,
+                 gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
                  fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -19,7 +19,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.runner import core_runner
@@ -29,7 +29,7 @@ from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig

 LR_DEFAULT = 1e-4
 LR_REST_DEFAULT = 1e-4
@@ -54,7 +54,7 @@ if FOUND_TORCH:
                               loss: Callable = multiple_tensors_mse_loss,
                               log_function: Callable = None,
                               use_hessian_based_weights: bool = True,
-                              regularization_factor: float = REG_DEFAULT) -> GradientPTQConfigV2:
+                              regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
         """
         Create a GradientPTQConfigV2 instance for Pytorch models.

@@ -86,21 +86,19 @@ if FOUND_TORCH:

         """
         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
-        return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
-                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
-                                   use_hessian_based_weights=use_hessian_based_weights,
-                                   regularization_factor=regularization_factor)
+        return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+                                 log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
+                                 use_hessian_based_weights=use_hessian_based_weights,
+                                 regularization_factor=regularization_factor)


     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
                                                     target_kpi: KPI = None,
                                                     core_config: CoreConfig = CoreConfig(),
-                                                    gptq_config: GradientPTQConfigV2 = None,
+                                                    gptq_config: GradientPTQConfig = None,
                                                     gptq_representative_data_gen: Callable = None,
-                                                    target_platform_capabilities: TargetPlatformCapabilities =
-                                                    DEFAULT_PYTORCH_TPC,
-                                                    new_experimental_exporter: bool = True):
+                                                    target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -122,10 +120,9 @@ if FOUND_TORCH:
             representative_data_gen (Callable): Dataset used for calibration.
             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
-            gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
+            gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
             gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
-            new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.

         Returns:
             A quantized module and information the user may need to handle the quantized module.
@@ -157,9 +154,9 @@ if FOUND_TORCH:
         """

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.")  # pragma: no cover

         tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
@@ -194,22 +191,8 @@ if FOUND_TORCH:
         if core_config.debug_config.analyze_similarity:
             analyzer_model_quantization(representative_data_gen, tb_w, graph_gptq, fw_impl, DEFAULT_PYTORCH_INFO)

-        # ---------------------- #
-        # Export
-        # ---------------------- #
-        if new_experimental_exporter:
-            Logger.warning('Using new experimental wrapped and ready for export models. To '
-                           'disable it, please set new_experimental_exporter to False when '
-                           'calling pytorch_gradient_post_training_quantization_experimental. '
-                           'If you encounter an issue please file a bug.')
-
-            return get_exportable_pytorch_model(graph_gptq)
-
-        return export_model(graph_gptq,
-                            DEFAULT_PYTORCH_INFO,
-                            fw_impl,
-                            tb_w,
-                            bit_widths_config)
+        return get_exportable_pytorch_model(graph_gptq)
+

 else:
     # If torch is not installed,
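The PyTorch facade drops the same flag and likewise always returns the exportable model. A sketch mirroring the Keras example above (the torchvision model and random data are illustrative assumptions):

```python
import numpy as np
import model_compression_toolkit as mct
from torchvision.models import mobilenet_v2

def repr_datagen():
    # Representative calibration data in NCHW layout; replace with real samples.
    yield [np.random.random((1, 3, 224, 224))]

gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)

# new_experimental_exporter is gone here as well.
quantized_module, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
    mobilenet_v2(), repr_datagen, gptq_config=gptq_config)
```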
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import List, Dict, Tuple

-from model_compression_toolkit.gptq import GradientPTQConfigV2
+from model_compression_toolkit.gptq import GradientPTQConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.pytorch.constants import KERNEL
 from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
@@ -34,7 +34,7 @@ from model_compression_toolkit.trainable_infrastructure.common.get_quantizers im


 def quantization_builder(n: common.BaseNode,
-                         gptq_config: GradientPTQConfigV2,
+                         gptq_config: GradientPTQConfig,
                          ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer],
                                     List[BasePyTorchInferableQuantizer]]:
     """
@@ -43,7 +43,7 @@ def quantization_builder(n: common.BaseNode,

     Args:
         n: Node to build its QuantizeConfig.
-        gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
+        gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration.

     Returns:
         A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import Callable

-from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
 from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
     SoftQuantizerRegularization

@@ -38,8 +38,6 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
         for _ in representative_data_gen():
             num_batches += 1

-        n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
-            not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
-        return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+        return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
     else:
         return lambda m, e_reg: 0
model_compression_toolkit/gptq/runner.py
@@ -20,7 +20,7 @@ from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.hessian import HessianInfoService
 from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
     apply_statistics_correction
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -32,7 +32,7 @@ from model_compression_toolkit.core.common.statistics_correction.apply_bias_corr
 from model_compression_toolkit.logger import Logger


-def _apply_gptq(gptq_config: GradientPTQConfigV2,
+def _apply_gptq(gptq_config: GradientPTQConfig,
                 representative_data_gen: Callable,
                 tb_w: TensorboardWriter,
                 tg: Graph,
@@ -74,7 +74,7 @@ def _apply_gptq(gptq_config: GradientPTQConfigV2,

 def gptq_runner(tg: Graph,
                 core_config: CoreConfig,
-                gptq_config: GradientPTQConfigV2,
+                gptq_config: GradientPTQConfig,
                 representative_data_gen: Callable,
                 gptq_representative_data_gen: Callable,
                 fw_info: FrameworkInfo,
model_compression_toolkit/pruning/__init__.py
@@ -16,4 +16,5 @@
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
 from model_compression_toolkit.core.common.pruning.pruning_config import ImportanceMetric, PruningConfig, ChannelsFilteringStrategy
 from model_compression_toolkit.pruning.keras.pruning_facade import keras_pruning_experimental
+from model_compression_toolkit.pruning.pytorch.pruning_facade import pytorch_pruning_experimental
model_compression_toolkit/pruning/pytorch/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
model_compression_toolkit/pruning/pytorch/pruning_facade.py
@@ -0,0 +1,166 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable, Tuple
+from model_compression_toolkit import get_target_platform_capabilities
+from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
+from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
+from model_compression_toolkit.core.common.pruning.pruner import Pruner
+from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
+from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
+from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
+from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
+from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
+
+
+# Check if PyTorch is available in the environment.
+if FOUND_TORCH:
+    # Import PyTorch-specific modules from the model compression toolkit.
+    from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
+    from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
+        PruningPytorchImplementation
+    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from torch.nn import Module
+
+    # Set the default Target Platform Capabilities (TPC) for PyTorch.
+    DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
+
+    def pytorch_pruning_experimental(model: Module,
+                                     target_kpi: KPI,
+                                     representative_data_gen: Callable,
+                                     pruning_config: PruningConfig = PruningConfig(),
+                                     target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
+            Tuple[Module, PruningInfo]:
+        """
+        Perform structured pruning on a Pytorch model to meet a specified target KPI.
+        This function prunes the provided model according to the target KPI by grouping and pruning
+        channels based on each layer's SIMD configuration in the Target Platform Capabilities (TPC).
+        By default, the importance of each channel group is determined using the Label-Free Hessian
+        (LFH) method, assessing each channel's sensitivity to the Hessian of the loss function.
+        This pruning strategy considers groups of channels together for a more hardware-friendly
+        architecture. The process involves analyzing the model with a representative dataset to
+        identify groups of channels that can be removed with minimal impact on performance.
+
+        Notice that the pruned model must be retrained to recover the compressed model's performance.
+
+        Args:
+            model (Module): The PyTorch model to be pruned.
+            target_kpi (KPI): Key Performance Indicators specifying the pruning targets.
+            representative_data_gen (Callable): A function to generate representative data for pruning analysis.
+            pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
+            target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
+                Defaults to DEFAULT_PYTORCH_TPC.
+
+        Returns:
+            Tuple[Model, PruningInfo]: A tuple containing the pruned Pytorch model and associated pruning information.
+
+        Note:
+            The pruned model should be fine-tuned or retrained to recover or improve its performance post-pruning.
+
+        Examples:
+
+            Import MCT:
+
+            >>> import model_compression_toolkit as mct
+
+            Import a Pytorch model:
+
+            >>> from torchvision.models import resnet50, ResNet50_Weights
+            >>> model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+
+            Create a random dataset generator:
+
+            >>> import numpy as np
+            >>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]
+
+            Define a target KPI for pruning.
+            Here, we aim to reduce the memory footprint of weights by 50%, assuming the model weights
+            are represented in float32 data type (thus, each parameter is represented using 4 bytes):
+
+            >>> dense_nparams = sum(p.numel() for p in model.state_dict().values())
+            >>> target_kpi = mct.KPI(weights_memory=dense_nparams * 4 * 0.5)
+
+            Optionally, define a pruning configuration. num_score_approximations can be passed
+            to configure the number of importance scores that will be calculated for each channel.
+            A higher value for this parameter yields more precise score approximations but also
+            extends the duration of the pruning process:
+
+            >>> pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)
+
+            Perform pruning:
+
+            >>> pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(model=model, target_kpi=target_kpi, representative_data_gen=repr_datagen, pruning_config=pruning_config)
+
+        """
+
+        # Instantiate the Pytorch framework implementation.
+        fw_impl = PruningPytorchImplementation()
+
+        # Convert the original Pytorch model to an internal graph representation.
+        float_graph = read_model_to_graph(model,
+                                          representative_data_gen,
+                                          target_platform_capabilities,
+                                          DEFAULT_PYTORCH_INFO,
+                                          fw_impl)
+
+        # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
+        # as it prepares the graph for the pruning process.
+        float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
+                                                                                      quant_config=DEFAULTCONFIG,
+                                                                                      mixed_precision_enable=False)
+
+        # Create a Pruner object with the graph and configuration.
+        pruner = Pruner(float_graph_with_compression_config,
+                        DEFAULT_PYTORCH_INFO,
+                        fw_impl,
+                        target_kpi,
+                        representative_data_gen,
+                        pruning_config,
+                        target_platform_capabilities)
+
+        # Apply the pruning process.
+        pruned_graph = pruner.prune_graph()
+
+        # Retrieve pruning information which includes the pruning masks and scores.
+        pruning_info = pruner.get_pruning_info()
+
+        # Rebuild the pruned graph back into a trainable Pytorch model.
+        pruned_model, _ = FloatPyTorchModelBuilder(graph=pruned_graph).build_model()
+        pruned_model.trainable = True
+
+        # Return the pruned model along with its pruning information.
+        return pruned_model, pruning_info
+
+else:
+    def pytorch_pruning_experimental(*args, **kwargs):
+        """
+        Raises a critical error if PyTorch is not installed but the pruning function is invoked.
+
+        This function acts as a placeholder to provide a clear error message when PyTorch dependencies are missing,
+        indicating that the pruning functionality cannot be used without the PyTorch framework installed.

+        Args:
+            *args: Variable length argument list, not used.
+            **kwargs: Arbitrary keyword arguments, not used.
+
+        Raises:
+            CriticalError: Indicates that PyTorch must be installed to use this function.
+        """
+        Logger.critical('Installing Pytorch is mandatory '
+                        'when using pytorch_pruning_experimental. '
+                        'Could not find the torch package.')  # pragma: no cover
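As the facade's docstring stresses, a pruned model should be retrained to recover accuracy. A minimal fine-tuning sketch for the returned module (the dataloader, loss, and hyperparameters are placeholders):

```python
import torch

def finetune(pruned_model, train_loader, epochs=2, lr=1e-4):
    # Standard supervised fine-tuning; pruning only removed channel groups,
    # so the result trains like any other torch.nn.Module.
    optimizer = torch.optim.SGD(pruned_model.parameters(), lr=lr, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()
    pruned_model.train()
    for _ in range(epochs):
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(pruned_model(images), labels)
            loss.backward()
            optimizer.step()
    return pruned_model
```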
model_compression_toolkit/ptq/__init__.py
@@ -13,5 +13,5 @@
 # limitations under the License.
 # ==============================================================================

-from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
-from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization_experimental
+from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization
+from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization
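Both PTQ entry points lose their `_experimental` suffix. A sketch of the renamed Keras call, assuming the `ptq` subpackage remains reachable as `mct.ptq` as in prior releases:

```python
import numpy as np
import model_compression_toolkit as mct
from tensorflow.keras.applications import MobileNetV2

def repr_datagen():
    # Representative calibration data; replace with real samples.
    yield [np.random.random((1, 224, 224, 3))]

# Previously keras_post_training_quantization_experimental.
quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    MobileNetV2(), repr_datagen)
```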