mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240305.post352__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/METADATA +5 -5
  2. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/RECORD +32 -30
  3. model_compression_toolkit/core/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +2 -2
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  7. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
  9. model_compression_toolkit/core/common/quantization/core_config.py +3 -3
  10. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
  11. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
  13. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  14. model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
  15. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
  16. model_compression_toolkit/core/pytorch/constants.py +3 -0
  17. model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
  18. model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
  19. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
  21. model_compression_toolkit/gptq/keras/quantization_facade.py +4 -4
  22. model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -3
  23. model_compression_toolkit/pruning/__init__.py +1 -0
  24. model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
  25. model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
  26. model_compression_toolkit/ptq/keras/quantization_facade.py +4 -7
  27. model_compression_toolkit/ptq/pytorch/quantization_facade.py +3 -6
  28. model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
  29. model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
  31. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
  32. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/LICENSE.md +0 -0
  33. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/WHEEL +0 -0
  34. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,166 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable, Tuple
+from model_compression_toolkit import get_target_platform_capabilities
+from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
+from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
+from model_compression_toolkit.core.common.pruning.pruner import Pruner
+from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
+from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
+from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
+from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
+from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
+
+
+# Check if PyTorch is available in the environment.
+if FOUND_TORCH:
+    # Import PyTorch-specific modules from the model compression toolkit.
+    from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
+    from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
+        PruningPytorchImplementation
+    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from torch.nn import Module
+
+    # Set the default Target Platform Capabilities (TPC) for PyTorch.
+    DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
+
+    def pytorch_pruning_experimental(model: Module,
+                                     target_kpi: KPI,
+                                     representative_data_gen: Callable,
+                                     pruning_config: PruningConfig = PruningConfig(),
+                                     target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
+            Tuple[Module, PruningInfo]:
+        """
+        Perform structured pruning on a Pytorch model to meet a specified target KPI.
+        This function prunes the provided model according to the target KPI by grouping and pruning
+        channels based on each layer's SIMD configuration in the Target Platform Capabilities (TPC).
+        By default, the importance of each channel group is determined using the Label-Free Hessian
+        (LFH) method, assessing each channel's sensitivity to the Hessian of the loss function.
+        This pruning strategy considers groups of channels together for a more hardware-friendly
+        architecture. The process involves analyzing the model with a representative dataset to
+        identify groups of channels that can be removed with minimal impact on performance.
+
+        Notice that the pruned model must be retrained to recover the compressed model's performance.
+
+        Args:
+            model (Module): The PyTorch model to be pruned.
+            target_kpi (KPI): Key Performance Indicators specifying the pruning targets.
+            representative_data_gen (Callable): A function to generate representative data for pruning analysis.
+            pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
+            target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
+                Defaults to DEFAULT_PYTORCH_TPC.
+
+        Returns:
+            Tuple[Model, PruningInfo]: A tuple containing the pruned Pytorch model and associated pruning information.
+
+        Note:
+            The pruned model should be fine-tuned or retrained to recover or improve its performance post-pruning.
+
+        Examples:
+
+            Import MCT:
+
+            >>> import model_compression_toolkit as mct
+
+            Import a Pytorch model:
+
+            >>> from torchvision.models import resnet50, ResNet50_Weights
+            >>> model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+
+            Create a random dataset generator:
+
+            >>> import numpy as np
+            >>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]
+
+            Define a target KPI for pruning.
+            Here, we aim to reduce the memory footprint of weights by 50%, assuming the model weights
+            are represented in float32 data type (thus, each parameter is represented using 4 bytes):
+
+            >>> dense_nparams = sum(p.numel() for p in model.state_dict().values())
+            >>> target_kpi = mct.KPI(weights_memory=dense_nparams * 4 * 0.5)
+
+            Optionally, define a pruning configuration. num_score_approximations can be passed
+            to configure the number of importance scores that will be calculated for each channel.
+            A higher value for this parameter yields more precise score approximations but also
+            extends the duration of the pruning process:
+
+            >>> pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)
+
+            Perform pruning:
+
+            >>> pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(model=model, target_kpi=target_kpi, representative_data_gen=repr_datagen, pruning_config=pruning_config)
+
+        """
+
+        # Instantiate the Pytorch framework implementation.
+        fw_impl = PruningPytorchImplementation()
+
+        # Convert the original Pytorch model to an internal graph representation.
+        float_graph = read_model_to_graph(model,
+                                          representative_data_gen,
+                                          target_platform_capabilities,
+                                          DEFAULT_PYTORCH_INFO,
+                                          fw_impl)
+
+        # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
+        # as it prepares the graph for the pruning process.
+        float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
+                                                                                      quant_config=DEFAULTCONFIG,
+                                                                                      mixed_precision_enable=False)
+
+        # Create a Pruner object with the graph and configuration.
+        pruner = Pruner(float_graph_with_compression_config,
+                        DEFAULT_PYTORCH_INFO,
+                        fw_impl,
+                        target_kpi,
+                        representative_data_gen,
+                        pruning_config,
+                        target_platform_capabilities)
+
+        # Apply the pruning process.
+        pruned_graph = pruner.prune_graph()
+
+        # Retrieve pruning information which includes the pruning masks and scores.
+        pruning_info = pruner.get_pruning_info()
+
+        # Rebuild the pruned graph back into a trainable Pytorch model.
+        pruned_model, _ = FloatPyTorchModelBuilder(graph=pruned_graph).build_model()
+        pruned_model.trainable = True
+
+        # Return the pruned model along with its pruning information.
+        return pruned_model, pruning_info
+
+else:
+    def pytorch_pruning_experimental(*args, **kwargs):
+        """
+        Raises a critical error if PyTorch is not installed but the pruning function is invoked.
+
+        This function acts as a placeholder to provide a clear error message when PyTorch dependencies are missing,
+        indicating that the pruning functionality cannot be used without the PyTorch framework installed.
+
+        Args:
+            *args: Variable length argument list, not used.
+            **kwargs: Arbitrary keyword arguments, not used.
+
+        Raises:
+            CriticalError: Indicates that PyTorch must be installed to use this function.
+        """
+        Logger.critical('Installing Pytorch is mandatory '
+                        'when using pytorch_pruning_experimental. '
+                        'Could not find the torch package.') # pragma: no cover
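
The facade's docstring ends at the pruning call, but both the description and the Note stress that the pruned model must be retrained to recover accuracy. Below is a minimal fine-tuning sketch in plain PyTorch; it is not part of the package, and the data loader, loss, and hyperparameters are illustrative assumptions.

import torch
import torch.nn as nn

def finetune_pruned_model(pruned_model: nn.Module, train_loader, epochs: int = 1, lr: float = 1e-4) -> nn.Module:
    # Short fine-tuning loop to recover accuracy lost to channel pruning.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pruned_model.to(device).train()
    optimizer = torch.optim.SGD(pruned_model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for images, labels in train_loader:  # assumed (input, label) batches
            optimizer.zero_grad()
            loss = criterion(pruned_model(images.to(device)), labels.to(device))
            loss.backward()
            optimizer.step()
    return pruned_model
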
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.runner import core_runner
@@ -99,7 +99,7 @@ if FOUND_TF:
         The candidates bitwidth for quantization should be defined in the target platform model.
         In this example we use 1 image to search mixed-precision configuration:

-        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
+        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -123,14 +123,11 @@ if FOUND_TF:
                              fw_info=fw_info).validate()

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.") # pragma: no cover

-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
-
         tb_w = init_tensorboard_writer(fw_info)

         fw_impl = KerasImplementation()
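
The migration is mechanical: callers drop the V2 suffix wherever a mixed-precision config is built. A sketch of the updated call pattern, assembled from the docstring above; the MobileNetV2 model and random data generator are placeholders, not part of this diff.

import numpy as np
import model_compression_toolkit as mct
from tensorflow.keras.applications import MobileNetV2

model = MobileNetV2()

def repr_datagen():
    yield [np.random.random((1, 224, 224, 3))]

# MixedPrecisionQuantizationConfigV2 -> MixedPrecisionQuantizationConfig
config = mct.core.CoreConfig(
    mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

# Bound the weights memory of the mixed-precision result in bytes, e.g. 75% of the float32 size.
dense_nparams = sum(int(np.prod(w.shape)) for w in model.weights)
target_kpi = mct.KPI(weights_memory=dense_nparams * 4 * 0.75)

# config and target_kpi are then passed to the Keras post-training-quantization facade.
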
@@ -22,7 +22,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform impo
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 from model_compression_toolkit.core.exporter import export_model
@@ -94,15 +94,12 @@ if FOUND_TORCH:
         """

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use "
+                             "MixedPrecisionQuantizationConfig. Please use "
                              "pytorch_post_training_quantization API, or pass a valid mixed precision "
                              "configuration.") # pragma: no cover

-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
-
         tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)

         fw_impl = PytorchImplementation()
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from mct_quantizers import KerasActivationQuantizationHolder
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
@@ -145,7 +145,7 @@ if FOUND_TF:
         If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:

-        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
+        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig())

         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -170,13 +170,10 @@ if FOUND_TF:
                              fw_info=fw_info).validate()

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization API,"
-                             "or pass a valid mixed precision configuration.")
-
-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
+                             "or pass a valid mixed precision configuration.")

         tb_w = init_tensorboard_writer(fw_info)

@@ -239,7 +236,7 @@ if FOUND_TF:
         If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:

-        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
+        >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig())

         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -25,7 +25,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
     TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
@@ -138,16 +138,12 @@ if FOUND_TORCH:
         """

         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
+                             "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
                              "or pass a valid mixed precision configuration.")

-            Logger.info("Using experimental mixed-precision quantization. "
-                        "If you encounter an issue please file a bug.")
-
         tb_w = init_tensorboard_writer(fw_info)
-
         fw_impl = PytorchImplementation()

         # Ignore trace hessian service as we do not use it here
@@ -1,64 +0,0 @@
-# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import numpy as np
-from sklearn.cluster import KMeans
-
-import model_compression_toolkit.core.common.quantization.quantization_config as qc
-from model_compression_toolkit.constants import LUT_VALUES, SCALE_PER_CHANNEL, MIN_THRESHOLD, EPS
-
-
-def kmeans_tensor(tensor_data: np.ndarray,
-                  p: int,
-                  n_bits: int,
-                  per_channel: bool = False,
-                  channel_axis: int = 1,
-                  n_iter: int = 10,
-                  min_threshold: float = MIN_THRESHOLD,
-                  quant_error_method: qc.QuantizationErrorMethod = None) -> dict:
-    """
-    Compute the 2^nbit cluster assignments for the given tensor according to the k-means algorithm.
-
-    Args:
-        tensor_data: Tensor content as Numpy array.
-        p: p-norm to use for the Lp-norm distance.
-        n_bits: Number of bits to quantize the tensor.
-        per_channel: Whether the quantization should be per-channel or not.
-        channel_axis: Output channel index.
-        n_iter: Number of iterations to search_methods for the optimal threshold.
-        min_threshold: Minimal threshold to chose when the computed one is smaller.
-        quant_error_method: an error function to optimize the parameters' selection accordingly (not used for this method).
-
-    Returns:
-        A dictionary containing the cluster assignments according to the k-means algorithm and the scales per channel.
-    """
-    if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
-        n_clusters = len(np.unique(tensor_data.flatten()))
-    else:
-        n_clusters = 2 ** n_bits
-    kmeans = KMeans(n_clusters=n_clusters)
-    axis_not_channel = [i for i in range(len(tensor_data.shape))]
-    if channel_axis in axis_not_channel:
-        axis_not_channel.remove(channel_axis)
-    if per_channel:
-        scales_per_channel = np.max(np.abs(tensor_data), axis=tuple(axis_not_channel), keepdims=True)
-    else:
-        scales_per_channel = np.max(np.abs(tensor_data), keepdims=True)
-    tensor_for_kmeans = (tensor_data / (scales_per_channel + EPS))
-    kmeans.fit(tensor_for_kmeans.reshape(-1, 1))
-
-    return {LUT_VALUES: kmeans.cluster_centers_,
-            SCALE_PER_CHANNEL: scales_per_channel,
-            }
@@ -1,53 +0,0 @@
-# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from sklearn.cluster import KMeans
-import numpy as np
-
-from model_compression_toolkit.constants import LUT_VALUES, MIN_THRESHOLD, SCALE_PER_CHANNEL
-from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters
-
-
-def kmeans_quantizer(tensor_data: np.ndarray,
-                     n_bits: int,
-                     signed: bool,
-                     quantization_params: dict,
-                     per_channel: bool,
-                     output_channels_axis: int) -> np.ndarray:
-    """
-    Quantize a tensor according to k-means algorithm. This function assigns cluster centers
-    to the tensor data values.
-
-    Args:
-        tensor_data: Tensor values to quantize.
-        n_bits: Number of bits to quantize the tensor.
-        signed: Whether the tensor contains negative values or not.
-        quantization_params: Dictionary of specific parameters for this quantization function.
-        per_channel: Whether to use separate quantization per output channel.
-        output_channels_axis: Axis of the output channel.
-
-    Returns:
-        Quantized data.
-    """
-    eps = 1e-8
-    lut_values = quantization_params[LUT_VALUES]
-    scales_per_channel = quantization_params[SCALE_PER_CHANNEL]
-    tensor = (tensor_data / (scales_per_channel + eps))
-    shape_before_kmeans = tensor.shape
-    cluster_assignments = kmeans_assign_clusters(lut_values, tensor.reshape(-1, 1))
-    quant_tensor = lut_values[cluster_assignments].reshape(shape_before_kmeans)
-    if per_channel:
-        quant_tensor = (quant_tensor * scales_per_channel)
-    return quant_tensor
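
For context on what the two deleted files implemented: kmeans_tensor computed a scale and a 2**n_bits LUT of cluster centers for a tensor, and kmeans_quantizer snapped the scaled values to their nearest LUT entry. The sketch below reproduces that flow standalone (per-tensor only); it uses sklearn's predict for cluster assignment instead of MCT's internal kmeans_assign_clusters helper, which is an assumption made for illustration.

import numpy as np
from sklearn.cluster import KMeans

def lut_kmeans_quantize(tensor_data: np.ndarray, n_bits: int = 4, eps: float = 1e-8) -> np.ndarray:
    # Per-tensor scale, mirroring the per_channel=False branch of the deleted kmeans_tensor.
    scale = np.max(np.abs(tensor_data), keepdims=True)
    normalized = tensor_data / (scale + eps)
    # At most 2**n_bits clusters, fewer if the tensor has fewer unique values.
    n_clusters = min(2 ** n_bits, len(np.unique(normalized)))
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(normalized.reshape(-1, 1))
    # Assign each value to its nearest LUT entry and rescale, as the deleted kmeans_quantizer did.
    lut_values = kmeans.cluster_centers_.flatten()
    assignments = kmeans.predict(normalized.reshape(-1, 1))
    return lut_values[assignments].reshape(tensor_data.shape) * scale

quantized = lut_kmeans_quantize(np.random.randn(16, 3, 3, 8).astype(np.float32), n_bits=4)
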