mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240306.post426__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/METADATA +5 -5
  2. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/RECORD +42 -40
  3. model_compression_toolkit/core/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +2 -2
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  7. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
  9. model_compression_toolkit/core/common/quantization/core_config.py +3 -3
  10. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
  11. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
  13. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  14. model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
  15. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
  16. model_compression_toolkit/core/pytorch/constants.py +3 -0
  17. model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
  18. model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
  19. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
  21. model_compression_toolkit/gptq/__init__.py +1 -1
  22. model_compression_toolkit/gptq/common/gptq_config.py +5 -72
  23. model_compression_toolkit/gptq/keras/gptq_training.py +2 -2
  24. model_compression_toolkit/gptq/keras/quantization_facade.py +19 -33
  25. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +3 -3
  26. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +2 -4
  27. model_compression_toolkit/gptq/pytorch/gptq_training.py +2 -2
  28. model_compression_toolkit/gptq/pytorch/quantization_facade.py +14 -31
  29. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +3 -3
  30. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +2 -4
  31. model_compression_toolkit/gptq/runner.py +3 -3
  32. model_compression_toolkit/pruning/__init__.py +1 -0
  33. model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
  34. model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
  35. model_compression_toolkit/ptq/__init__.py +2 -2
  36. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -30
  37. model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -30
  38. model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
  39. model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
  41. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
  42. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/LICENSE.md +0 -0
  43. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/WHEEL +0 -0
  44. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/top_level.txt +0 -0
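Taken together, the hunks below tell one story: the post-training-quantization facades graduate out of their experimental names. The `_experimental` suffix is dropped from `keras_post_training_quantization` and `pytorch_post_training_quantization`, the `new_experimental_exporter` flag disappears (the exportable-model path is now the only one), `MixedPrecisionQuantizationConfigV2` is renamed to `MixedPrecisionQuantizationConfig`, PyTorch structured pruning lands (`pruning/pytorch/pruning_facade.py`), and the legacy k-means (LUT) quantizer files are deleted. A minimal migration sketch, where `model` and `repr_datagen` are hypothetical placeholders for a trained Keras model and a representative data generator, not names from this diff:

import model_compression_toolkit as mct

# Before (1.11.0.20240304.post404):
# quantized_model, info = mct.ptq.keras_post_training_quantization_experimental(
#     model, repr_datagen, new_experimental_exporter=True)

# After (1.11.0.20240306.post426); matches the old default (exporter-enabled) path.
quantized_model, info = mct.ptq.keras_post_training_quantization(model, repr_datagen)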
--- a/model_compression_toolkit/ptq/keras/quantization_facade.py
+++ b/model_compression_toolkit/ptq/keras/quantization_facade.py
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.runner import core_runner
@@ -40,12 +40,11 @@ if FOUND_TF:
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)


-    def keras_post_training_quantization_experimental(in_model: Model,
-                                                      representative_data_gen: Callable,
-                                                      target_kpi: KPI = None,
-                                                      core_config: CoreConfig = CoreConfig(),
-                                                      target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC,
-                                                      new_experimental_exporter: bool = True):
+    def keras_post_training_quantization(in_model: Model,
+                                         representative_data_gen: Callable,
+                                         target_kpi: KPI = None,
+                                         core_config: CoreConfig = CoreConfig(),
+                                         target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -65,7 +64,6 @@ if FOUND_TF:
             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
-            new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.

         Returns:

@@ -99,7 +97,7 @@ if FOUND_TF:
         The candidates bitwidth for quantization should be defined in the target platform model.
         In this example we use 1 image to search mixed-precision configuration:

-        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
+        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -111,7 +109,7 @@ if FOUND_TF:
         Pass the model, the representative dataset generator, the configuration and the target KPI to get a
         quantized model:

-        >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization_experimental(model, repr_datagen, kpi, core_config=config)
+        >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model, repr_datagen, kpi, core_config=config)

         For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

@@ -123,14 +121,11 @@ if FOUND_TF:
                              fw_info=fw_info).validate()

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.") # pragma: no cover

-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
-
         tb_w = init_tensorboard_writer(fw_info)

         fw_impl = KerasImplementation()
@@ -153,26 +148,14 @@ if FOUND_TF:
                                                fw_impl,
                                                fw_info)

-        if new_experimental_exporter:
-            Logger.warning('Using new experimental wrapped and ready for export models. To '
-                           'disable it, please set new_experimental_exporter to False when '
-                           'calling keras_post_training_quantization_experimental. '
-                           'If you encounter an issue please file a bug.')
-
-            return get_exportable_keras_model(tg)
-
-        return export_model(tg,
-                            fw_info,
-                            fw_impl,
-                            tb_w,
-                            bit_widths_config)
+        return get_exportable_keras_model(tg)



 else:
     # If tensorflow is not installed,
     # we raise an exception when trying to use these functions.
-    def keras_post_training_quantization_experimental(*args, **kwargs):
+    def keras_post_training_quantization(*args, **kwargs):
         Logger.critical('Installing tensorflow is mandatory '
-                        'when using keras_post_training_quantization_experimental. '
+                        'when using keras_post_training_quantization. '
                         'Could not find Tensorflow package.') # pragma: no cover
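Assembling the renamed pieces from this file's docstring, a mixed-precision call now reads as below. This is a sketch: `model` and `repr_datagen` are placeholders, and the budget arithmetic assumes `mct.core.KPI` takes a weights-memory budget as its first positional argument (an assumption, not shown in this diff).

import model_compression_toolkit as mct

# Mixed-precision search driven by a single representative image.
config = mct.core.CoreConfig(
    mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

# Assumed KPI signature: first argument is the weights-memory budget.
# Cap coefficient memory at 75% of the model's parameter count.
kpi = mct.core.KPI(model.count_params() * 0.75)

quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    model, repr_datagen, kpi, core_config=config)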
--- a/model_compression_toolkit/ptq/pytorch/quantization_facade.py
+++ b/model_compression_toolkit/ptq/pytorch/quantization_facade.py
@@ -22,7 +22,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform impo
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 from model_compression_toolkit.core.exporter import export_model
@@ -39,12 +39,11 @@ if FOUND_TORCH:

     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

-    def pytorch_post_training_quantization_experimental(in_module: Module,
-                                                        representative_data_gen: Callable,
-                                                        target_kpi: KPI = None,
-                                                        core_config: CoreConfig = CoreConfig(),
-                                                        target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC,
-                                                        new_experimental_exporter: bool = True):
+    def pytorch_post_training_quantization(in_module: Module,
+                                           representative_data_gen: Callable,
+                                           target_kpi: KPI = None,
+                                           core_config: CoreConfig = CoreConfig(),
+                                           target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -64,7 +63,6 @@ if FOUND_TORCH:
             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
-            new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.

         Returns:
             A quantized module and information the user may need to handle the quantized module.
@@ -89,20 +87,17 @@ if FOUND_TORCH:
         Set number of clibration iterations to 1:

         >>> import model_compression_toolkit as mct
-        >>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization_experimental(module, repr_datagen)
+        >>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(module, repr_datagen)

         """

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use "
+                             "MixedPrecisionQuantizationConfig. Please use "
                              "pytorch_post_training_quantization API, or pass a valid mixed precision "
                              "configuration.") # pragma: no cover

-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
-
         tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)

         fw_impl = PytorchImplementation()
@@ -126,26 +121,13 @@ if FOUND_TORCH:
                                                fw_impl,
                                                DEFAULT_PYTORCH_INFO)

-        if new_experimental_exporter:
-            Logger.warning('Using new experimental wrapped and ready for export models. To '
-                           'disable it, please set new_experimental_exporter to False when '
-                           'calling pytorch_post_training_quantization_experimental. '
-                           'If you encounter an issue please file a bug.')
-
-            return get_exportable_pytorch_model(tg)
-
-        quantized_model, user_info = export_model(tg,
-                                                  DEFAULT_PYTORCH_INFO,
-                                                  fw_impl,
-                                                  tb_w,
-                                                  bit_widths_config)
+        return get_exportable_pytorch_model(tg)

-        return quantized_model, user_info

 else:
     # If torch is not installed,
     # we raise an exception when trying to use these functions.
-    def pytorch_post_training_quantization_experimental(*args, **kwargs):
+    def pytorch_post_training_quantization(*args, **kwargs):
         Logger.critical('Installing Pytorch is mandatory '
-                        'when using pytorch_post_training_quantization_experimental. '
+                        'when using pytorch_post_training_quantization. '
                         'Could not find the torch package.') # pragma: no cover
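The PyTorch facade mirrors the Keras change. Per the docstring kept in the hunk above, the default call is now simply the following, with `module` and `repr_datagen` as placeholders for a trained torch.nn.Module and a representative data generator:

import model_compression_toolkit as mct

# The new_experimental_exporter switch is gone: an exportable quantized
# module (plus quantization info) is now the only return path.
quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(
    module, repr_datagen)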
--- a/model_compression_toolkit/qat/keras/quantization_facade.py
+++ b/model_compression_toolkit/qat/keras/quantization_facade.py
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from mct_quantizers import KerasActivationQuantizationHolder
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
@@ -145,7 +145,7 @@ if FOUND_TF:
         If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:

-        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
+        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig())

         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -170,13 +170,10 @@ if FOUND_TF:
                              fw_info=fw_info).validate()

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization API,"
-                             "or pass a valid mixed precision configuration.")
-
-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
+                             "or pass a valid mixed precision configuration.")

         tb_w = init_tensorboard_writer(fw_info)

@@ -239,7 +236,7 @@ if FOUND_TF:
         If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:

-        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
+        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig())

         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
--- a/model_compression_toolkit/qat/pytorch/quantization_facade.py
+++ b/model_compression_toolkit/qat/pytorch/quantization_facade.py
@@ -25,7 +25,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
     TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
@@ -138,16 +138,12 @@ if FOUND_TORCH:
         """

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
+                             "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
                              "or pass a valid mixed precision configuration.")

-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
-
         tb_w = init_tensorboard_writer(fw_info)
-
         fw_impl = PytorchImplementation()

         # Ignore trace hessian service as we do not use it here
--- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import numpy as np
-from sklearn.cluster import KMeans
-
-import model_compression_toolkit.core.common.quantization.quantization_config as qc
-from model_compression_toolkit.constants import LUT_VALUES, SCALE_PER_CHANNEL, MIN_THRESHOLD, EPS
-
-
-def kmeans_tensor(tensor_data: np.ndarray,
-                  p: int,
-                  n_bits: int,
-                  per_channel: bool = False,
-                  channel_axis: int = 1,
-                  n_iter: int = 10,
-                  min_threshold: float = MIN_THRESHOLD,
-                  quant_error_method: qc.QuantizationErrorMethod = None) -> dict:
-    """
-    Compute the 2^nbit cluster assignments for the given tensor according to the k-means algorithm.
-
-    Args:
-        tensor_data: Tensor content as Numpy array.
-        p: p-norm to use for the Lp-norm distance.
-        n_bits: Number of bits to quantize the tensor.
-        per_channel: Whether the quantization should be per-channel or not.
-        channel_axis: Output channel index.
-        n_iter: Number of iterations to search_methods for the optimal threshold.
-        min_threshold: Minimal threshold to chose when the computed one is smaller.
-        quant_error_method: an error function to optimize the parameters' selection accordingly (not used for this method).
-
-    Returns:
-        A dictionary containing the cluster assignments according to the k-means algorithm and the scales per channel.
-    """
-    if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
-        n_clusters = len(np.unique(tensor_data.flatten()))
-    else:
-        n_clusters = 2 ** n_bits
-    kmeans = KMeans(n_clusters=n_clusters)
-    axis_not_channel = [i for i in range(len(tensor_data.shape))]
-    if channel_axis in axis_not_channel:
-        axis_not_channel.remove(channel_axis)
-    if per_channel:
-        scales_per_channel = np.max(np.abs(tensor_data), axis=tuple(axis_not_channel), keepdims=True)
-    else:
-        scales_per_channel = np.max(np.abs(tensor_data), keepdims=True)
-    tensor_for_kmeans = (tensor_data / (scales_per_channel + EPS))
-    kmeans.fit(tensor_for_kmeans.reshape(-1, 1))
-
-    return {LUT_VALUES: kmeans.cluster_centers_,
-            SCALE_PER_CHANNEL: scales_per_channel,
-            }
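For readers tracking the removal: the deleted `kmeans_tensor` normalized the tensor by its max-abs scale (per channel or per tensor), clustered the normalized values into at most 2^n_bits centers, and returned the centers plus scales as LUT parameters. A self-contained sketch of that behavior, using plain numpy/scikit-learn rather than MCT internals (the names here are illustrative, not MCT API):

import numpy as np
from sklearn.cluster import KMeans

def lut_params(tensor_data: np.ndarray, n_bits: int = 4) -> dict:
    # Use at most 2^n_bits clusters; fewer if the tensor has fewer unique values.
    n_clusters = min(len(np.unique(tensor_data)), 2 ** n_bits)
    scale = np.max(np.abs(tensor_data), keepdims=True)  # per-tensor scale
    normalized = tensor_data / (scale + 1e-8)           # small eps, as in the removed code
    km = KMeans(n_clusters=n_clusters, n_init=10).fit(normalized.reshape(-1, 1))
    return {'lut_values': km.cluster_centers_, 'scale': scale}

params = lut_params(np.random.randn(8, 8).astype(np.float32), n_bits=2)
print(params['lut_values'].shape)  # (4, 1): four centers for a 2-bit LUT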
--- a/model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from sklearn.cluster import KMeans
-import numpy as np
-
-from model_compression_toolkit.constants import LUT_VALUES, MIN_THRESHOLD, SCALE_PER_CHANNEL
-from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters
-
-
-def kmeans_quantizer(tensor_data: np.ndarray,
-                     n_bits: int,
-                     signed: bool,
-                     quantization_params: dict,
-                     per_channel: bool,
-                     output_channels_axis: int) -> np.ndarray:
-    """
-    Quantize a tensor according to k-means algorithm. This function assigns cluster centers
-    to the tensor data values.
-
-    Args:
-        tensor_data: Tensor values to quantize.
-        n_bits: Number of bits to quantize the tensor.
-        signed: Whether the tensor contains negative values or not.
-        quantization_params: Dictionary of specific parameters for this quantization function.
-        per_channel: Whether to use separate quantization per output channel.
-        output_channels_axis: Axis of the output channel.
-
-    Returns:
-        Quantized data.
-    """
-    eps = 1e-8
-    lut_values = quantization_params[LUT_VALUES]
-    scales_per_channel = quantization_params[SCALE_PER_CHANNEL]
-    tensor = (tensor_data / (scales_per_channel + eps))
-    shape_before_kmeans = tensor.shape
-    cluster_assignments = kmeans_assign_clusters(lut_values, tensor.reshape(-1, 1))
-    quant_tensor = lut_values[cluster_assignments].reshape(shape_before_kmeans)
-    if per_channel:
-        quant_tensor = (quant_tensor * scales_per_channel)
-    return quant_tensor
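The companion `kmeans_quantizer` above closed the loop: normalize, snap each value to its nearest LUT center, rescale. The internal `kmeans_assign_clusters` helper is not shown in this diff, so a nearest-center argmin stands in for it in this hedged round-trip sketch (toy LUT values, not MCT output):

import numpy as np

def lut_quantize(tensor_data: np.ndarray, lut_values: np.ndarray,
                 scale: np.ndarray) -> np.ndarray:
    normalized = tensor_data / (scale + 1e-8)
    flat = normalized.reshape(-1, 1)                              # (N, 1)
    # Nearest-center assignment; stand-in for kmeans_assign_clusters.
    assignments = np.argmin(np.abs(flat - lut_values.T), axis=1)  # (N,)
    quantized = lut_values[assignments].reshape(tensor_data.shape)
    return quantized * scale  # rescale back to the original range

x = np.random.randn(8, 8).astype(np.float32)
scale = np.max(np.abs(x), keepdims=True)
lut = np.array([[-0.9], [-0.3], [0.3], [0.9]])  # toy 2-bit LUT in normalized space
print(np.unique(lut_quantize(x, lut, scale)).size <= 4)  # True: at most 2^2 levels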