mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240306.post426__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/METADATA +5 -5
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/RECORD +42 -40
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
- model_compression_toolkit/core/common/quantization/core_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
- model_compression_toolkit/core/pytorch/constants.py +3 -0
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
- model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
- model_compression_toolkit/gptq/__init__.py +1 -1
- model_compression_toolkit/gptq/common/gptq_config.py +5 -72
- model_compression_toolkit/gptq/keras/gptq_training.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +19 -33
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +3 -3
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +2 -4
- model_compression_toolkit/gptq/pytorch/gptq_training.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +14 -31
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +2 -4
- model_compression_toolkit/gptq/runner.py +3 -3
- model_compression_toolkit/pruning/__init__.py +1 -0
- model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
- model_compression_toolkit/ptq/__init__.py +2 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -30
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -30
- model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/top_level.txt +0 -0
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
|
|
|
22
22
|
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
25
|
-
|
|
25
|
+
MixedPrecisionQuantizationConfig
|
|
26
26
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
27
27
|
from model_compression_toolkit.core.exporter import export_model
|
|
28
28
|
from model_compression_toolkit.core.runner import core_runner
|
|
@@ -40,12 +40,11 @@ if FOUND_TF:
|
|
|
40
40
|
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
|
41
41
|
|
|
42
42
|
|
|
43
|
-
def
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
new_experimental_exporter: bool = True):
|
|
43
|
+
def keras_post_training_quantization(in_model: Model,
|
|
44
|
+
representative_data_gen: Callable,
|
|
45
|
+
target_kpi: KPI = None,
|
|
46
|
+
core_config: CoreConfig = CoreConfig(),
|
|
47
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
|
|
49
48
|
"""
|
|
50
49
|
Quantize a trained Keras model using post-training quantization. The model is quantized using a
|
|
51
50
|
symmetric constraint quantization thresholds (power of two).
|
|
@@ -65,7 +64,6 @@ if FOUND_TF:
|
|
|
65
64
|
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
|
|
66
65
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
|
|
67
66
|
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
68
|
-
new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
|
|
69
67
|
|
|
70
68
|
Returns:
|
|
71
69
|
|
|
@@ -99,7 +97,7 @@ if FOUND_TF:
|
|
|
99
97
|
The candidates bitwidth for quantization should be defined in the target platform model.
|
|
100
98
|
In this example we use 1 image to search mixed-precision configuration:
|
|
101
99
|
|
|
102
|
-
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.
|
|
100
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
|
|
103
101
|
|
|
104
102
|
For mixed-precision set a target KPI object:
|
|
105
103
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
@@ -111,7 +109,7 @@ if FOUND_TF:
|
|
|
111
109
|
Pass the model, the representative dataset generator, the configuration and the target KPI to get a
|
|
112
110
|
quantized model:
|
|
113
111
|
|
|
114
|
-
>>> quantized_model, quantization_info = mct.ptq.
|
|
112
|
+
>>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model, repr_datagen, kpi, core_config=config)
|
|
115
113
|
|
|
116
114
|
For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
|
|
117
115
|
|
|
@@ -123,14 +121,11 @@ if FOUND_TF:
|
|
|
123
121
|
fw_info=fw_info).validate()
|
|
124
122
|
|
|
125
123
|
if core_config.mixed_precision_enable:
|
|
126
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
124
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
127
125
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
128
|
-
"
|
|
126
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
|
|
129
127
|
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
130
128
|
|
|
131
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
132
|
-
"If you encounter an issue please file a bug.")
|
|
133
|
-
|
|
134
129
|
tb_w = init_tensorboard_writer(fw_info)
|
|
135
130
|
|
|
136
131
|
fw_impl = KerasImplementation()
|
|
@@ -153,26 +148,14 @@ if FOUND_TF:
|
|
|
153
148
|
fw_impl,
|
|
154
149
|
fw_info)
|
|
155
150
|
|
|
156
|
-
|
|
157
|
-
Logger.warning('Using new experimental wrapped and ready for export models. To '
|
|
158
|
-
'disable it, please set new_experimental_exporter to False when '
|
|
159
|
-
'calling keras_post_training_quantization_experimental. '
|
|
160
|
-
'If you encounter an issue please file a bug.')
|
|
161
|
-
|
|
162
|
-
return get_exportable_keras_model(tg)
|
|
163
|
-
|
|
164
|
-
return export_model(tg,
|
|
165
|
-
fw_info,
|
|
166
|
-
fw_impl,
|
|
167
|
-
tb_w,
|
|
168
|
-
bit_widths_config)
|
|
151
|
+
return get_exportable_keras_model(tg)
|
|
169
152
|
|
|
170
153
|
|
|
171
154
|
|
|
172
155
|
else:
|
|
173
156
|
# If tensorflow is not installed,
|
|
174
157
|
# we raise an exception when trying to use these functions.
|
|
175
|
-
def
|
|
158
|
+
def keras_post_training_quantization(*args, **kwargs):
|
|
176
159
|
Logger.critical('Installing tensorflow is mandatory '
|
|
177
|
-
'when using
|
|
160
|
+
'when using keras_post_training_quantization. '
|
|
178
161
|
'Could not find Tensorflow package.') # pragma: no cover
|
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform impo
|
|
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
23
23
|
from model_compression_toolkit.core import CoreConfig
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
25
|
-
|
|
25
|
+
MixedPrecisionQuantizationConfig
|
|
26
26
|
from model_compression_toolkit.core.runner import core_runner
|
|
27
27
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
28
28
|
from model_compression_toolkit.core.exporter import export_model
|
|
@@ -39,12 +39,11 @@ if FOUND_TORCH:
|
|
|
39
39
|
|
|
40
40
|
DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
|
41
41
|
|
|
42
|
-
def
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
new_experimental_exporter: bool = True):
|
|
42
|
+
def pytorch_post_training_quantization(in_module: Module,
|
|
43
|
+
representative_data_gen: Callable,
|
|
44
|
+
target_kpi: KPI = None,
|
|
45
|
+
core_config: CoreConfig = CoreConfig(),
|
|
46
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
|
|
48
47
|
"""
|
|
49
48
|
Quantize a trained Pytorch module using post-training quantization.
|
|
50
49
|
By default, the module is quantized using a symmetric constraint quantization thresholds
|
|
@@ -64,7 +63,6 @@ if FOUND_TORCH:
|
|
|
64
63
|
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
|
|
65
64
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
|
|
66
65
|
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
67
|
-
new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
|
|
68
66
|
|
|
69
67
|
Returns:
|
|
70
68
|
A quantized module and information the user may need to handle the quantized module.
|
|
@@ -89,20 +87,17 @@ if FOUND_TORCH:
|
|
|
89
87
|
Set number of clibration iterations to 1:
|
|
90
88
|
|
|
91
89
|
>>> import model_compression_toolkit as mct
|
|
92
|
-
>>> quantized_module, quantization_info = mct.ptq.
|
|
90
|
+
>>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(module, repr_datagen)
|
|
93
91
|
|
|
94
92
|
"""
|
|
95
93
|
|
|
96
94
|
if core_config.mixed_precision_enable:
|
|
97
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
95
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
98
96
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
99
|
-
"
|
|
97
|
+
"MixedPrecisionQuantizationConfig. Please use "
|
|
100
98
|
"pytorch_post_training_quantization API, or pass a valid mixed precision "
|
|
101
99
|
"configuration.") # pragma: no cover
|
|
102
100
|
|
|
103
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
104
|
-
"If you encounter an issue please file a bug.")
|
|
105
|
-
|
|
106
101
|
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
|
|
107
102
|
|
|
108
103
|
fw_impl = PytorchImplementation()
|
|
@@ -126,26 +121,13 @@ if FOUND_TORCH:
|
|
|
126
121
|
fw_impl,
|
|
127
122
|
DEFAULT_PYTORCH_INFO)
|
|
128
123
|
|
|
129
|
-
|
|
130
|
-
Logger.warning('Using new experimental wrapped and ready for export models. To '
|
|
131
|
-
'disable it, please set new_experimental_exporter to False when '
|
|
132
|
-
'calling pytorch_post_training_quantization_experimental. '
|
|
133
|
-
'If you encounter an issue please file a bug.')
|
|
134
|
-
|
|
135
|
-
return get_exportable_pytorch_model(tg)
|
|
136
|
-
|
|
137
|
-
quantized_model, user_info = export_model(tg,
|
|
138
|
-
DEFAULT_PYTORCH_INFO,
|
|
139
|
-
fw_impl,
|
|
140
|
-
tb_w,
|
|
141
|
-
bit_widths_config)
|
|
124
|
+
return get_exportable_pytorch_model(tg)
|
|
142
125
|
|
|
143
|
-
return quantized_model, user_info
|
|
144
126
|
|
|
145
127
|
else:
|
|
146
128
|
# If torch is not installed,
|
|
147
129
|
# we raise an exception when trying to use these functions.
|
|
148
|
-
def
|
|
130
|
+
def pytorch_post_training_quantization(*args, **kwargs):
|
|
149
131
|
Logger.critical('Installing Pytorch is mandatory '
|
|
150
|
-
'when using
|
|
132
|
+
'when using pytorch_post_training_quantization. '
|
|
151
133
|
'Could not find the torch package.') # pragma: no cover
|
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
|
|
|
22
22
|
from model_compression_toolkit.constants import FOUND_TF
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
25
|
-
|
|
25
|
+
MixedPrecisionQuantizationConfig
|
|
26
26
|
from mct_quantizers import KerasActivationQuantizationHolder
|
|
27
27
|
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
|
|
28
28
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
@@ -145,7 +145,7 @@ if FOUND_TF:
|
|
|
145
145
|
If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
|
|
146
146
|
The candidates bitwidth for quantization should be defined in the target platform model:
|
|
147
147
|
|
|
148
|
-
>>> config = mct.core.CoreConfig(mixed_precision_config=
|
|
148
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig())
|
|
149
149
|
|
|
150
150
|
For mixed-precision set a target KPI object:
|
|
151
151
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
@@ -170,13 +170,10 @@ if FOUND_TF:
|
|
|
170
170
|
fw_info=fw_info).validate()
|
|
171
171
|
|
|
172
172
|
if core_config.mixed_precision_enable:
|
|
173
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
173
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
174
174
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
179
|
-
"If you encounter an issue please file a bug.")
|
|
175
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
|
|
176
|
+
"or pass a valid mixed precision configuration.")
|
|
180
177
|
|
|
181
178
|
tb_w = init_tensorboard_writer(fw_info)
|
|
182
179
|
|
|
@@ -239,7 +236,7 @@ if FOUND_TF:
|
|
|
239
236
|
If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
|
|
240
237
|
The candidates bitwidth for quantization should be defined in the target platform model:
|
|
241
238
|
|
|
242
|
-
>>> config = mct.core.CoreConfig(mixed_precision_config=
|
|
239
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig())
|
|
243
240
|
|
|
244
241
|
For mixed-precision set a target KPI object:
|
|
245
242
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
@@ -25,7 +25,7 @@ from model_compression_toolkit.logger import Logger
|
|
|
25
25
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
26
26
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
27
27
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
28
|
-
|
|
28
|
+
MixedPrecisionQuantizationConfig
|
|
29
29
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
|
|
30
30
|
TargetPlatformCapabilities
|
|
31
31
|
from model_compression_toolkit.core.runner import core_runner
|
|
@@ -138,16 +138,12 @@ if FOUND_TORCH:
|
|
|
138
138
|
"""
|
|
139
139
|
|
|
140
140
|
if core_config.mixed_precision_enable:
|
|
141
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
141
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
142
142
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
143
|
-
"
|
|
143
|
+
"MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
|
|
144
144
|
"or pass a valid mixed precision configuration.")
|
|
145
145
|
|
|
146
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
147
|
-
"If you encounter an issue please file a bug.")
|
|
148
|
-
|
|
149
146
|
tb_w = init_tensorboard_writer(fw_info)
|
|
150
|
-
|
|
151
147
|
fw_impl = PytorchImplementation()
|
|
152
148
|
|
|
153
149
|
# Ignore trace hessian service as we do not use it here
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
import numpy as np
|
|
17
|
-
from sklearn.cluster import KMeans
|
|
18
|
-
|
|
19
|
-
import model_compression_toolkit.core.common.quantization.quantization_config as qc
|
|
20
|
-
from model_compression_toolkit.constants import LUT_VALUES, SCALE_PER_CHANNEL, MIN_THRESHOLD, EPS
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def kmeans_tensor(tensor_data: np.ndarray,
|
|
24
|
-
p: int,
|
|
25
|
-
n_bits: int,
|
|
26
|
-
per_channel: bool = False,
|
|
27
|
-
channel_axis: int = 1,
|
|
28
|
-
n_iter: int = 10,
|
|
29
|
-
min_threshold: float = MIN_THRESHOLD,
|
|
30
|
-
quant_error_method: qc.QuantizationErrorMethod = None) -> dict:
|
|
31
|
-
"""
|
|
32
|
-
Compute the 2^nbit cluster assignments for the given tensor according to the k-means algorithm.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
tensor_data: Tensor content as Numpy array.
|
|
36
|
-
p: p-norm to use for the Lp-norm distance.
|
|
37
|
-
n_bits: Number of bits to quantize the tensor.
|
|
38
|
-
per_channel: Whether the quantization should be per-channel or not.
|
|
39
|
-
channel_axis: Output channel index.
|
|
40
|
-
n_iter: Number of iterations to search_methods for the optimal threshold.
|
|
41
|
-
min_threshold: Minimal threshold to chose when the computed one is smaller.
|
|
42
|
-
quant_error_method: an error function to optimize the parameters' selection accordingly (not used for this method).
|
|
43
|
-
|
|
44
|
-
Returns:
|
|
45
|
-
A dictionary containing the cluster assignments according to the k-means algorithm and the scales per channel.
|
|
46
|
-
"""
|
|
47
|
-
if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
|
|
48
|
-
n_clusters = len(np.unique(tensor_data.flatten()))
|
|
49
|
-
else:
|
|
50
|
-
n_clusters = 2 ** n_bits
|
|
51
|
-
kmeans = KMeans(n_clusters=n_clusters)
|
|
52
|
-
axis_not_channel = [i for i in range(len(tensor_data.shape))]
|
|
53
|
-
if channel_axis in axis_not_channel:
|
|
54
|
-
axis_not_channel.remove(channel_axis)
|
|
55
|
-
if per_channel:
|
|
56
|
-
scales_per_channel = np.max(np.abs(tensor_data), axis=tuple(axis_not_channel), keepdims=True)
|
|
57
|
-
else:
|
|
58
|
-
scales_per_channel = np.max(np.abs(tensor_data), keepdims=True)
|
|
59
|
-
tensor_for_kmeans = (tensor_data / (scales_per_channel + EPS))
|
|
60
|
-
kmeans.fit(tensor_for_kmeans.reshape(-1, 1))
|
|
61
|
-
|
|
62
|
-
return {LUT_VALUES: kmeans.cluster_centers_,
|
|
63
|
-
SCALE_PER_CHANNEL: scales_per_channel,
|
|
64
|
-
}
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from sklearn.cluster import KMeans
|
|
17
|
-
import numpy as np
|
|
18
|
-
|
|
19
|
-
from model_compression_toolkit.constants import LUT_VALUES, MIN_THRESHOLD, SCALE_PER_CHANNEL
|
|
20
|
-
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def kmeans_quantizer(tensor_data: np.ndarray,
|
|
24
|
-
n_bits: int,
|
|
25
|
-
signed: bool,
|
|
26
|
-
quantization_params: dict,
|
|
27
|
-
per_channel: bool,
|
|
28
|
-
output_channels_axis: int) -> np.ndarray:
|
|
29
|
-
"""
|
|
30
|
-
Quantize a tensor according to k-means algorithm. This function assigns cluster centers
|
|
31
|
-
to the tensor data values.
|
|
32
|
-
|
|
33
|
-
Args:
|
|
34
|
-
tensor_data: Tensor values to quantize.
|
|
35
|
-
n_bits: Number of bits to quantize the tensor.
|
|
36
|
-
signed: Whether the tensor contains negative values or not.
|
|
37
|
-
quantization_params: Dictionary of specific parameters for this quantization function.
|
|
38
|
-
per_channel: Whether to use separate quantization per output channel.
|
|
39
|
-
output_channels_axis: Axis of the output channel.
|
|
40
|
-
|
|
41
|
-
Returns:
|
|
42
|
-
Quantized data.
|
|
43
|
-
"""
|
|
44
|
-
eps = 1e-8
|
|
45
|
-
lut_values = quantization_params[LUT_VALUES]
|
|
46
|
-
scales_per_channel = quantization_params[SCALE_PER_CHANNEL]
|
|
47
|
-
tensor = (tensor_data / (scales_per_channel + eps))
|
|
48
|
-
shape_before_kmeans = tensor.shape
|
|
49
|
-
cluster_assignments = kmeans_assign_clusters(lut_values, tensor.reshape(-1, 1))
|
|
50
|
-
quant_tensor = lut_values[cluster_assignments].reshape(shape_before_kmeans)
|
|
51
|
-
if per_channel:
|
|
52
|
-
quant_tensor = (quant_tensor * scales_per_channel)
|
|
53
|
-
return quant_tensor
|
|
File without changes
|
|
File without changes
|
|
File without changes
|