mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240305.post352__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/METADATA +5 -5
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/RECORD +32 -30
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
- model_compression_toolkit/core/common/quantization/core_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
- model_compression_toolkit/core/pytorch/constants.py +3 -0
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
- model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +4 -4
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -3
- model_compression_toolkit/pruning/__init__.py +1 -0
- model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +4 -7
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +3 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Callable, Tuple
|
|
17
|
+
from model_compression_toolkit import get_target_platform_capabilities
|
|
18
|
+
from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
|
|
19
|
+
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
20
|
+
from model_compression_toolkit.core.common.pruning.pruner import Pruner
|
|
21
|
+
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
|
|
22
|
+
from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
|
|
23
|
+
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
|
|
24
|
+
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
|
|
25
|
+
from model_compression_toolkit.logger import Logger
|
|
26
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
27
|
+
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
|
28
|
+
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Check if PyTorch is available in the environment.
|
|
32
|
+
if FOUND_TORCH:
|
|
33
|
+
# Import PyTorch-specific modules from the model compression toolkit.
|
|
34
|
+
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
|
|
35
|
+
from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
|
|
36
|
+
PruningPytorchImplementation
|
|
37
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
38
|
+
from torch.nn import Module
|
|
39
|
+
|
|
40
|
+
# Set the default Target Platform Capabilities (TPC) for PyTorch.
|
|
41
|
+
DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
|
42
|
+
|
|
43
|
+
def pytorch_pruning_experimental(model: Module,
|
|
44
|
+
target_kpi: KPI,
|
|
45
|
+
representative_data_gen: Callable,
|
|
46
|
+
pruning_config: PruningConfig = PruningConfig(),
|
|
47
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
|
|
48
|
+
Tuple[Module, PruningInfo]:
|
|
49
|
+
"""
|
|
50
|
+
Perform structured pruning on a Pytorch model to meet a specified target KPI.
|
|
51
|
+
This function prunes the provided model according to the target KPI by grouping and pruning
|
|
52
|
+
channels based on each layer's SIMD configuration in the Target Platform Capabilities (TPC).
|
|
53
|
+
By default, the importance of each channel group is determined using the Label-Free Hessian
|
|
54
|
+
(LFH) method, assessing each channel's sensitivity to the Hessian of the loss function.
|
|
55
|
+
This pruning strategy considers groups of channels together for a more hardware-friendly
|
|
56
|
+
architecture. The process involves analyzing the model with a representative dataset to
|
|
57
|
+
identify groups of channels that can be removed with minimal impact on performance.
|
|
58
|
+
|
|
59
|
+
Notice that the pruned model must be retrained to recover the compressed model's performance.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
model (Module): The PyTorch model to be pruned.
|
|
63
|
+
target_kpi (KPI): Key Performance Indicators specifying the pruning targets.
|
|
64
|
+
representative_data_gen (Callable): A function to generate representative data for pruning analysis.
|
|
65
|
+
pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
|
|
66
|
+
target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
|
|
67
|
+
Defaults to the default PyTorch TPC (module constant DEFAULT_PYOTRCH_TPC).
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
Tuple[Module, PruningInfo]: A tuple containing the pruned PyTorch model and associated pruning information.
|
|
71
|
+
|
|
72
|
+
Note:
|
|
73
|
+
The pruned model should be fine-tuned or retrained to recover or improve its performance post-pruning.
|
|
74
|
+
|
|
75
|
+
Examples:
|
|
76
|
+
|
|
77
|
+
Import MCT:
|
|
78
|
+
|
|
79
|
+
>>> import model_compression_toolkit as mct
|
|
80
|
+
|
|
81
|
+
Import a Pytorch model:
|
|
82
|
+
|
|
83
|
+
>>> from torchvision.models import resnet50, ResNet50_Weights
|
|
84
|
+
>>> model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
|
85
|
+
|
|
86
|
+
Create a random dataset generator:
|
|
87
|
+
|
|
88
|
+
>>> import numpy as np
|
|
89
|
+
>>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]
|
|
90
|
+
|
|
91
|
+
Define a target KPI for pruning.
|
|
92
|
+
Here, we aim to reduce the memory footprint of weights by 50%, assuming the model weights
|
|
93
|
+
are represented in float32 data type (thus, each parameter is represented using 4 bytes):
|
|
94
|
+
|
|
95
|
+
>>> dense_nparams = sum(p.numel() for p in model.state_dict().values())
|
|
96
|
+
>>> target_kpi = mct.KPI(weights_memory=dense_nparams * 4 * 0.5)
|
|
97
|
+
|
|
98
|
+
Optionally, define a pruning configuration. num_score_approximations can be passed
|
|
99
|
+
to configure the number of importance scores that will be calculated for each channel.
|
|
100
|
+
A higher value for this parameter yields more precise score approximations but also
|
|
101
|
+
extends the duration of the pruning process:
|
|
102
|
+
|
|
103
|
+
>>> pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)
|
|
104
|
+
|
|
105
|
+
Perform pruning:
|
|
106
|
+
|
|
107
|
+
>>> pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(model=model, target_kpi=target_kpi, representative_data_gen=repr_datagen, pruning_config=pruning_config)
|
|
108
|
+
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
# Instantiate the Pytorch framework implementation.
|
|
112
|
+
fw_impl = PruningPytorchImplementation()
|
|
113
|
+
|
|
114
|
+
# Convert the original Pytorch model to an internal graph representation.
|
|
115
|
+
float_graph = read_model_to_graph(model,
|
|
116
|
+
representative_data_gen,
|
|
117
|
+
target_platform_capabilities,
|
|
118
|
+
DEFAULT_PYTORCH_INFO,
|
|
119
|
+
fw_impl)
|
|
120
|
+
|
|
121
|
+
# Apply quantization configuration to the graph. This step is necessary even when not quantizing,
|
|
122
|
+
# as it prepares the graph for the pruning process.
|
|
123
|
+
float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
|
|
124
|
+
quant_config=DEFAULTCONFIG,
|
|
125
|
+
mixed_precision_enable=False)
|
|
126
|
+
|
|
127
|
+
# Create a Pruner object with the graph and configuration.
|
|
128
|
+
pruner = Pruner(float_graph_with_compression_config,
|
|
129
|
+
DEFAULT_PYTORCH_INFO,
|
|
130
|
+
fw_impl,
|
|
131
|
+
target_kpi,
|
|
132
|
+
representative_data_gen,
|
|
133
|
+
pruning_config,
|
|
134
|
+
target_platform_capabilities)
|
|
135
|
+
|
|
136
|
+
# Apply the pruning process.
|
|
137
|
+
pruned_graph = pruner.prune_graph()
|
|
138
|
+
|
|
139
|
+
# Retrieve pruning information which includes the pruning masks and scores.
|
|
140
|
+
pruning_info = pruner.get_pruning_info()
|
|
141
|
+
|
|
142
|
+
# Rebuild the pruned graph back into a trainable Pytorch model.
|
|
143
|
+
pruned_model, _ = FloatPyTorchModelBuilder(graph=pruned_graph).build_model()
|
|
144
|
+
pruned_model.trainable = True
|
|
145
|
+
|
|
146
|
+
# Return the pruned model along with its pruning information.
|
|
147
|
+
return pruned_model, pruning_info
|
|
148
|
+
|
|
149
|
+
else:
|
|
150
|
+
def pytorch_pruning_experimental(*args, **kwargs):
|
|
151
|
+
"""
|
|
152
|
+
Raises a critical error if PyTorch is not installed but the pruning function is invoked.
|
|
153
|
+
|
|
154
|
+
This function acts as a placeholder to provide a clear error message when PyTorch dependencies are missing,
|
|
155
|
+
indicating that the pruning functionality cannot be used without the PyTorch framework installed.
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
*args: Variable length argument list, not used.
|
|
159
|
+
**kwargs: Arbitrary keyword arguments, not used.
|
|
160
|
+
|
|
161
|
+
Raises:
|
|
162
|
+
CriticalError: Indicates that PyTorch must be installed to use this function.
|
|
163
|
+
"""
|
|
164
|
+
Logger.critical('Installing Pytorch is mandatory '
|
|
165
|
+
'when using pytorch_pruning_experimental. '
|
|
166
|
+
'Could not find the torch package.') # pragma: no cover
|
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
|
|
|
22
22
|
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
25
|
-
|
|
25
|
+
MixedPrecisionQuantizationConfig
|
|
26
26
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
27
27
|
from model_compression_toolkit.core.exporter import export_model
|
|
28
28
|
from model_compression_toolkit.core.runner import core_runner
|
|
@@ -99,7 +99,7 @@ if FOUND_TF:
|
|
|
99
99
|
The candidates bitwidth for quantization should be defined in the target platform model.
|
|
100
100
|
In this example we use 1 image to search mixed-precision configuration:
|
|
101
101
|
|
|
102
|
-
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.
|
|
102
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
|
|
103
103
|
|
|
104
104
|
For mixed-precision set a target KPI object:
|
|
105
105
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
@@ -123,14 +123,11 @@ if FOUND_TF:
|
|
|
123
123
|
fw_info=fw_info).validate()
|
|
124
124
|
|
|
125
125
|
if core_config.mixed_precision_enable:
|
|
126
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
126
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
127
127
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
128
|
-
"
|
|
128
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
|
|
129
129
|
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
130
130
|
|
|
131
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
132
|
-
"If you encounter an issue please file a bug.")
|
|
133
|
-
|
|
134
131
|
tb_w = init_tensorboard_writer(fw_info)
|
|
135
132
|
|
|
136
133
|
fw_impl = KerasImplementation()
|
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform impo
|
|
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
23
23
|
from model_compression_toolkit.core import CoreConfig
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
25
|
-
|
|
25
|
+
MixedPrecisionQuantizationConfig
|
|
26
26
|
from model_compression_toolkit.core.runner import core_runner
|
|
27
27
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
28
28
|
from model_compression_toolkit.core.exporter import export_model
|
|
@@ -94,15 +94,12 @@ if FOUND_TORCH:
|
|
|
94
94
|
"""
|
|
95
95
|
|
|
96
96
|
if core_config.mixed_precision_enable:
|
|
97
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
97
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
98
98
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
99
|
-
"
|
|
99
|
+
"MixedPrecisionQuantizationConfig. Please use "
|
|
100
100
|
"pytorch_post_training_quantization API, or pass a valid mixed precision "
|
|
101
101
|
"configuration.") # pragma: no cover
|
|
102
102
|
|
|
103
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
104
|
-
"If you encounter an issue please file a bug.")
|
|
105
|
-
|
|
106
103
|
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
|
|
107
104
|
|
|
108
105
|
fw_impl = PytorchImplementation()
|
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
|
|
|
22
22
|
from model_compression_toolkit.constants import FOUND_TF
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
25
|
-
|
|
25
|
+
MixedPrecisionQuantizationConfig
|
|
26
26
|
from mct_quantizers import KerasActivationQuantizationHolder
|
|
27
27
|
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
|
|
28
28
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
@@ -145,7 +145,7 @@ if FOUND_TF:
|
|
|
145
145
|
If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
|
|
146
146
|
The candidates bitwidth for quantization should be defined in the target platform model:
|
|
147
147
|
|
|
148
|
-
>>> config = mct.core.CoreConfig(mixed_precision_config=
|
|
148
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig())
|
|
149
149
|
|
|
150
150
|
For mixed-precision set a target KPI object:
|
|
151
151
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
@@ -170,13 +170,10 @@ if FOUND_TF:
|
|
|
170
170
|
fw_info=fw_info).validate()
|
|
171
171
|
|
|
172
172
|
if core_config.mixed_precision_enable:
|
|
173
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
173
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
174
174
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
179
|
-
"If you encounter an issue please file a bug.")
|
|
175
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
|
|
176
|
+
"or pass a valid mixed precision configuration.")
|
|
180
177
|
|
|
181
178
|
tb_w = init_tensorboard_writer(fw_info)
|
|
182
179
|
|
|
@@ -239,7 +236,7 @@ if FOUND_TF:
|
|
|
239
236
|
If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
|
|
240
237
|
The candidates bitwidth for quantization should be defined in the target platform model:
|
|
241
238
|
|
|
242
|
-
>>> config = mct.core.CoreConfig(mixed_precision_config=
|
|
239
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig())
|
|
243
240
|
|
|
244
241
|
For mixed-precision set a target KPI object:
|
|
245
242
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
@@ -25,7 +25,7 @@ from model_compression_toolkit.logger import Logger
|
|
|
25
25
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
26
26
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
27
27
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
28
|
-
|
|
28
|
+
MixedPrecisionQuantizationConfig
|
|
29
29
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
|
|
30
30
|
TargetPlatformCapabilities
|
|
31
31
|
from model_compression_toolkit.core.runner import core_runner
|
|
@@ -138,16 +138,12 @@ if FOUND_TORCH:
|
|
|
138
138
|
"""
|
|
139
139
|
|
|
140
140
|
if core_config.mixed_precision_enable:
|
|
141
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
141
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
142
142
|
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
143
|
-
"
|
|
143
|
+
"MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
|
|
144
144
|
"or pass a valid mixed precision configuration.")
|
|
145
145
|
|
|
146
|
-
Logger.info("Using experimental mixed-precision quantization. "
|
|
147
|
-
"If you encounter an issue please file a bug.")
|
|
148
|
-
|
|
149
146
|
tb_w = init_tensorboard_writer(fw_info)
|
|
150
|
-
|
|
151
147
|
fw_impl = PytorchImplementation()
|
|
152
148
|
|
|
153
149
|
# Ignore trace hessian service as we do not use it here
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
import numpy as np
|
|
17
|
-
from sklearn.cluster import KMeans
|
|
18
|
-
|
|
19
|
-
import model_compression_toolkit.core.common.quantization.quantization_config as qc
|
|
20
|
-
from model_compression_toolkit.constants import LUT_VALUES, SCALE_PER_CHANNEL, MIN_THRESHOLD, EPS
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def kmeans_tensor(tensor_data: np.ndarray,
|
|
24
|
-
p: int,
|
|
25
|
-
n_bits: int,
|
|
26
|
-
per_channel: bool = False,
|
|
27
|
-
channel_axis: int = 1,
|
|
28
|
-
n_iter: int = 10,
|
|
29
|
-
min_threshold: float = MIN_THRESHOLD,
|
|
30
|
-
quant_error_method: qc.QuantizationErrorMethod = None) -> dict:
|
|
31
|
-
"""
|
|
32
|
-
Compute the 2^nbit cluster assignments for the given tensor according to the k-means algorithm.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
tensor_data: Tensor content as Numpy array.
|
|
36
|
-
p: p-norm to use for the Lp-norm distance.
|
|
37
|
-
n_bits: Number of bits to quantize the tensor.
|
|
38
|
-
per_channel: Whether the quantization should be per-channel or not.
|
|
39
|
-
channel_axis: Output channel index.
|
|
40
|
-
n_iter: Number of iterations to search_methods for the optimal threshold.
|
|
41
|
-
min_threshold: Minimal threshold to chose when the computed one is smaller.
|
|
42
|
-
quant_error_method: an error function to optimize the parameters' selection accordingly (not used for this method).
|
|
43
|
-
|
|
44
|
-
Returns:
|
|
45
|
-
A dictionary containing the cluster assignments according to the k-means algorithm and the scales per channel.
|
|
46
|
-
"""
|
|
47
|
-
if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
|
|
48
|
-
n_clusters = len(np.unique(tensor_data.flatten()))
|
|
49
|
-
else:
|
|
50
|
-
n_clusters = 2 ** n_bits
|
|
51
|
-
kmeans = KMeans(n_clusters=n_clusters)
|
|
52
|
-
axis_not_channel = [i for i in range(len(tensor_data.shape))]
|
|
53
|
-
if channel_axis in axis_not_channel:
|
|
54
|
-
axis_not_channel.remove(channel_axis)
|
|
55
|
-
if per_channel:
|
|
56
|
-
scales_per_channel = np.max(np.abs(tensor_data), axis=tuple(axis_not_channel), keepdims=True)
|
|
57
|
-
else:
|
|
58
|
-
scales_per_channel = np.max(np.abs(tensor_data), keepdims=True)
|
|
59
|
-
tensor_for_kmeans = (tensor_data / (scales_per_channel + EPS))
|
|
60
|
-
kmeans.fit(tensor_for_kmeans.reshape(-1, 1))
|
|
61
|
-
|
|
62
|
-
return {LUT_VALUES: kmeans.cluster_centers_,
|
|
63
|
-
SCALE_PER_CHANNEL: scales_per_channel,
|
|
64
|
-
}
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from sklearn.cluster import KMeans
|
|
17
|
-
import numpy as np
|
|
18
|
-
|
|
19
|
-
from model_compression_toolkit.constants import LUT_VALUES, MIN_THRESHOLD, SCALE_PER_CHANNEL
|
|
20
|
-
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def kmeans_quantizer(tensor_data: np.ndarray,
|
|
24
|
-
n_bits: int,
|
|
25
|
-
signed: bool,
|
|
26
|
-
quantization_params: dict,
|
|
27
|
-
per_channel: bool,
|
|
28
|
-
output_channels_axis: int) -> np.ndarray:
|
|
29
|
-
"""
|
|
30
|
-
Quantize a tensor according to k-means algorithm. This function assigns cluster centers
|
|
31
|
-
to the tensor data values.
|
|
32
|
-
|
|
33
|
-
Args:
|
|
34
|
-
tensor_data: Tensor values to quantize.
|
|
35
|
-
n_bits: Number of bits to quantize the tensor.
|
|
36
|
-
signed: Whether the tensor contains negative values or not.
|
|
37
|
-
quantization_params: Dictionary of specific parameters for this quantization function.
|
|
38
|
-
per_channel: Whether to use separate quantization per output channel.
|
|
39
|
-
output_channels_axis: Axis of the output channel.
|
|
40
|
-
|
|
41
|
-
Returns:
|
|
42
|
-
Quantized data.
|
|
43
|
-
"""
|
|
44
|
-
eps = 1e-8
|
|
45
|
-
lut_values = quantization_params[LUT_VALUES]
|
|
46
|
-
scales_per_channel = quantization_params[SCALE_PER_CHANNEL]
|
|
47
|
-
tensor = (tensor_data / (scales_per_channel + eps))
|
|
48
|
-
shape_before_kmeans = tensor.shape
|
|
49
|
-
cluster_assignments = kmeans_assign_clusters(lut_values, tensor.reshape(-1, 1))
|
|
50
|
-
quant_tensor = lut_values[cluster_assignments].reshape(shape_before_kmeans)
|
|
51
|
-
if per_channel:
|
|
52
|
-
quant_tensor = (quant_tensor * scales_per_channel)
|
|
53
|
-
return quant_tensor
|
|
File without changes
|
|
File without changes
|
|
File without changes
|