mct-nightly 2.4.0.20250706.701__py3-none-any.whl → 2.4.0.20250708.612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/RECORD +36 -38
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/collectors/base_collector.py +4 -1
  5. model_compression_toolkit/core/common/collectors/mean_collector.py +7 -4
  6. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +7 -4
  7. model_compression_toolkit/core/common/model_collector.py +17 -3
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +1 -1
  9. model_compression_toolkit/core/common/quantization/node_quantization_config.py +25 -87
  10. model_compression_toolkit/core/common/quantization/quantization_config.py +0 -1
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +26 -17
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +27 -49
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +12 -7
  14. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -14
  15. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -1
  16. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -13
  17. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +3 -3
  18. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +5 -7
  19. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +7 -5
  20. model_compression_toolkit/core/graph_prep_runner.py +1 -11
  21. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  22. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +21 -11
  23. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  24. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +8 -0
  25. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  26. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +9 -1
  27. model_compression_toolkit/core/quantization_prep_runner.py +2 -2
  28. model_compression_toolkit/gptq/keras/quantization_facade.py +0 -3
  29. model_compression_toolkit/ptq/keras/quantization_facade.py +0 -3
  30. model_compression_toolkit/qat/keras/quantization_facade.py +0 -3
  31. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +0 -2
  32. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +0 -6
  33. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +2 -4
  34. model_compression_toolkit/core/common/model_validation.py +0 -41
  35. model_compression_toolkit/core/keras/keras_model_validation.py +0 -37
  36. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/WHEEL +0 -0
  37. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/licenses/LICENSE.md +0 -0
  38. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/top_level.txt +0 -0
@@ -29,6 +29,10 @@ import numpy as np
29
29
  from model_compression_toolkit.logger import Logger
30
30
 
31
31
 
32
+ # default output channel axis to use when it's not defined in node's fw_info.
33
+ _default_output_channel_axis = 1
34
+
35
+
32
36
  class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
33
37
  """
34
38
  Implementation of the PruningFramework for the Pytorch framework. This class provides
@@ -190,6 +194,10 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
190
194
 
191
195
  return attributes_with_axis
192
196
 
197
+ @property
198
+ def default_output_channel_axis(self):
199
+ return _default_output_channel_axis
200
+
193
201
 
194
202
  def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
195
203
  """
@@ -283,7 +291,7 @@ def _edit_node_input_shape(node: BaseNode,
283
291
 
284
292
  # Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
285
293
  # This is done by summing the mask, as each '1' in the mask represents a retained channel.
286
- channel_axis = node.out_channel_axis
294
+ channel_axis = _default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
287
295
  new_input_shape[0][channel_axis] = int(np.sum(input_mask))
288
296
 
289
297
  # Update the node's input shape with the new dimensions.
@@ -87,8 +87,8 @@ def quantization_preparation_runner(graph: Graph,
87
87
  # Calculate quantization params
88
88
  ######################################
89
89
 
90
- calculate_quantization_params(graph, fw_impl=fw_impl, repr_data_gen_fn=representative_data_gen,
91
- hessian_info_service=hessian_info_service)
90
+ calculate_quantization_params(graph, core_config.quantization_config, fw_impl=fw_impl,
91
+ repr_data_gen_fn=representative_data_gen, hessian_info_service=hessian_info_service)
92
92
 
93
93
  if tb_w is not None:
94
94
  tb_w.add_graph(graph, 'thresholds_selection')
@@ -43,7 +43,6 @@ if FOUND_TF:
43
43
  import tensorflow as tf
44
44
  from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
45
45
  from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
46
- from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
47
46
  from tensorflow.keras.models import Model
48
47
  from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
49
48
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
@@ -235,8 +234,6 @@ if FOUND_TF:
235
234
  if core_config.debug_config.bypass:
236
235
  return in_model, None
237
236
 
238
- KerasModelValidation(model=in_model).validate()
239
-
240
237
  if core_config.is_mixed_precision_enabled:
241
238
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
242
239
  Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
@@ -38,7 +38,6 @@ if FOUND_TF:
38
38
  AttachTpcToKeras
39
39
  from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
40
40
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
41
- from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
42
41
  from tensorflow.keras.models import Model
43
42
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
44
43
  from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
@@ -129,8 +128,6 @@ if FOUND_TF:
129
128
  if core_config.debug_config.bypass:
130
129
  return in_model, None
131
130
 
132
- KerasModelValidation(model=in_model).validate()
133
-
134
131
  if core_config.is_mixed_precision_enabled:
135
132
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
136
133
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
@@ -38,7 +38,6 @@ if FOUND_TF:
38
38
 
39
39
  from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
40
40
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
41
- from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
42
41
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
43
42
  from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
44
43
 
@@ -175,8 +174,6 @@ if FOUND_TF:
175
174
  f"If you encounter an issue, please open an issue in our GitHub "
176
175
  f"project https://github.com/sony/model_optimization")
177
176
 
178
- KerasModelValidation(model=in_model).validate()
179
-
180
177
  if core_config.is_mixed_precision_enabled:
181
178
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
182
179
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
@@ -48,7 +48,6 @@ def get_trainable_quantizer_weights_config(
48
48
  final_attr_cfg.enable_weights_quantization,
49
49
  final_attr_cfg.weights_channels_axis[0], # Output channel axis
50
50
  final_attr_cfg.weights_per_channel_threshold,
51
- final_node_cfg.min_threshold,
52
51
  weights_quantization_candidates)
53
52
 
54
53
 
@@ -76,7 +75,6 @@ def get_trainable_quantizer_activation_config(
76
75
  final_cfg.activation_n_bits,
77
76
  final_cfg.activation_quantization_params,
78
77
  final_cfg.enable_activation_quantization,
79
- final_cfg.min_threshold,
80
78
  activation_quantization_candidates)
81
79
 
82
80
 
@@ -44,7 +44,6 @@ class TrainableQuantizerActivationConfig:
44
44
  activation_n_bits: int,
45
45
  activation_quantization_params: Dict,
46
46
  enable_activation_quantization: bool,
47
- min_threshold: float,
48
47
  activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
49
48
  ):
50
49
  """
@@ -55,13 +54,11 @@ class TrainableQuantizerActivationConfig:
55
54
  activation_n_bits (int): Number of bits to quantize the activations.
56
55
  activation_quantization_params (Dict): Dictionary that contains activation quantization params.
57
56
  enable_activation_quantization (bool): Whether to quantize the layer's activations or not.
58
- min_threshold (float): Minimum threshold to use during thresholds selection.
59
57
  """
60
58
  self.activation_quantization_method = activation_quantization_method
61
59
  self.activation_n_bits = activation_n_bits
62
60
  self.activation_quantization_params = activation_quantization_params
63
61
  self.enable_activation_quantization = enable_activation_quantization
64
- self.min_threshold = min_threshold
65
62
  self.activation_bits_candidates = activation_quantization_candidates
66
63
 
67
64
 
@@ -73,7 +70,6 @@ class TrainableQuantizerWeightsConfig:
73
70
  enable_weights_quantization: bool,
74
71
  weights_channels_axis: int,
75
72
  weights_per_channel_threshold: bool,
76
- min_threshold: float,
77
73
  weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
78
74
  ):
79
75
  """
@@ -86,7 +82,6 @@ class TrainableQuantizerWeightsConfig:
86
82
  enable_weights_quantization (bool): Whether to quantize the layer's weights or not.
87
83
  weights_channels_axis (int): Axis to quantize a node's kernel when quantizing per-channel.
88
84
  weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
89
- min_threshold (float): Minimum threshold to use during thresholds selection.
90
85
  """
91
86
  self.weights_quantization_method = weights_quantization_method
92
87
  self.weights_n_bits = weights_n_bits
@@ -94,5 +89,4 @@ class TrainableQuantizerWeightsConfig:
94
89
  self.enable_weights_quantization = enable_weights_quantization
95
90
  self.weights_channels_axis = weights_channels_axis
96
91
  self.weights_per_channel_threshold = weights_per_channel_threshold
97
- self.min_threshold = min_threshold
98
92
  self.weights_bits_candidates = weights_quantization_candidates
@@ -77,13 +77,11 @@ def config_deserialization(in_config: dict) -> Union[TrainableQuantizerWeightsCo
77
77
  weights_quantization_params=weights_quantization_params,
78
78
  enable_weights_quantization=in_config[C.ENABLE_WEIGHTS_QUANTIZATION],
79
79
  weights_channels_axis=in_config[C.WEIGHTS_CHANNELS_AXIS],
80
- weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD],
81
- min_threshold=in_config[C.MIN_THRESHOLD])
80
+ weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD])
82
81
  elif in_config[C.IS_ACTIVATIONS]:
83
82
  return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod(in_config[C.ACTIVATION_QUANTIZATION_METHOD]),
84
83
  activation_n_bits=in_config[C.ACTIVATION_N_BITS],
85
84
  activation_quantization_params=in_config[C.ACTIVATION_QUANTIZATION_PARAMS],
86
- enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION],
87
- min_threshold=in_config[C.MIN_THRESHOLD])
85
+ enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION])
88
86
  else:
89
87
  raise NotImplemented # pragma: no cover
@@ -1,41 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any
3
-
4
- from model_compression_toolkit.core import FrameworkInfo
5
-
6
-
7
- class ModelValidation:
8
- """
9
- Class to define validation methods in order to validate the received model to quantize.
10
- """
11
-
12
- def __init__(self,
13
- model: Any):
14
- """
15
- Initialize a ModelValidation object.
16
-
17
- Args:
18
- model: Model to check its validity.
19
- """
20
- self.model = model
21
-
22
- @abstractmethod
23
- def validate_output_channel_consistency(self):
24
- """
25
-
26
- Validate that output channels index in all layers of the model are the same.
27
- If the model has layers with different output channels index, it should throw an exception.
28
-
29
- """
30
- raise NotImplemented(
31
- f'Framework validation class did not implement validate_output_channel_consistency') # pragma: no cover
32
-
33
- def validate(self):
34
- """
35
-
36
- Run all validation methods before the quantization process starts.
37
-
38
- """
39
- self.validate_output_channel_consistency()
40
-
41
-
@@ -1,37 +0,0 @@
1
- from tensorflow.keras.models import Model
2
-
3
- from model_compression_toolkit.core.common.framework_info import get_fw_info
4
- from model_compression_toolkit.core.common.framework_info import ChannelAxis
5
- from model_compression_toolkit.core.common.model_validation import ModelValidation
6
- from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
7
-
8
-
9
- class KerasModelValidation(ModelValidation):
10
- """
11
- Class to define validation methods in order to validate the received Keras model to quantize.
12
- """
13
-
14
- def __init__(self, model: Model):
15
- """
16
- Initialize a KerasModelValidation object.
17
-
18
- Args:
19
- model: Keras model to check its validity.
20
- """
21
-
22
- super(KerasModelValidation, self).__init__(model=model)
23
-
24
- def validate_output_channel_consistency(self):
25
- """
26
-
27
- Validate that output channels index in all layers of the model are the same.
28
- If the model has layers with different output channels index, an exception is thrown.
29
-
30
- """
31
- fw_info = get_fw_info()
32
- for layer in self.model.layers:
33
- data_format = layer.get_config().get(CHANNELS_FORMAT)
34
- if data_format is not None:
35
- assert (data_format == CHANNELS_FORMAT_LAST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NHWC.value
36
- or data_format == CHANNELS_FORMAT_FIRST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NCHW.value), \
37
- f'Model can not have layers with different data formats.'