mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.02122021-003117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/METADATA +3 -2
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/RECORD +31 -38
- model_compression_toolkit/__init__.py +2 -6
- model_compression_toolkit/common/base_substitutions.py +1 -0
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
- model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +2 -4
- model_compression_toolkit/common/graph/graph_matchers.py +3 -1
- model_compression_toolkit/common/graph/graph_searches.py +3 -1
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/common/network_editors/node_filters.py +1 -0
- model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
- model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
- model_compression_toolkit/common/quantization/quantize_node.py +3 -5
- model_compression_toolkit/keras/__init__.py +2 -0
- model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
- model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
- model_compression_toolkit/keras/default_framework_info.py +0 -1
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
- model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
- model_compression_toolkit/keras/quantization_facade.py +524 -188
- model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
- model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
- model_compression_toolkit/common/framework_implementation.py +0 -239
- model_compression_toolkit/common/gptq/__init__.py +0 -14
- model_compression_toolkit/common/gptq/gptq_config.py +0 -65
- model_compression_toolkit/common/model_builder_mode.py +0 -34
- model_compression_toolkit/common/post_training_quantization.py +0 -459
- model_compression_toolkit/common/substitutions/__init__.py +0 -14
- model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
- model_compression_toolkit/keras/keras_implementation.py +0 -256
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/LICENSE +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/WHEEL +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/top_level.txt +0 -0
|
@@ -1,256 +0,0 @@
|
|
|
1
|
-
from typing import List, Any, Tuple, Callable
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import tensorflow as tf
|
|
5
|
-
from tensorflow.keras.models import Model
|
|
6
|
-
|
|
7
|
-
from model_compression_toolkit import QuantizationConfig, FrameworkInfo, common, GradientPTQConfig, \
|
|
8
|
-
MixedPrecisionQuantizationConfig
|
|
9
|
-
from model_compression_toolkit.common import Graph, Node
|
|
10
|
-
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
11
|
-
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
12
|
-
from model_compression_toolkit.common.user_info import UserInformation
|
|
13
|
-
from model_compression_toolkit.keras.back2framework.model_builder import model_builder
|
|
14
|
-
from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
15
|
-
from model_compression_toolkit.keras.gradient_ptq.training_wrapper import gptq_training_wrapper
|
|
16
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.activation_decomposition import \
|
|
17
|
-
ActivationDecomposition
|
|
18
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.batchnorm_folding import \
|
|
19
|
-
BatchNormalizationFolding
|
|
20
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
|
|
21
|
-
InputScalingWithPad
|
|
22
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.mark_activation import MarkActivation
|
|
23
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.relu_bound_correction import \
|
|
24
|
-
ReLUBoundCorrection
|
|
25
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.remove_relu_upper_bound import \
|
|
26
|
-
RemoveReLUUpperBound
|
|
27
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.scale_equalization import \
|
|
28
|
-
ScaleEqualization, ScaleEqualizationWithPad, ScaleEqualizationMidActivation, ScaleEqualizationMidActivationWithPad
|
|
29
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.separableconv_decomposition import \
|
|
30
|
-
SeparableConvDecomposition
|
|
31
|
-
from model_compression_toolkit.keras.graph_substitutions.substitutions.shift_negative_activation import \
|
|
32
|
-
apply_shift_negative_correction
|
|
33
|
-
from model_compression_toolkit.keras.mixed_precision.sensitivity_evaluation import get_sensitivity_evaluation
|
|
34
|
-
from model_compression_toolkit.keras.reader.reader import model_reader
|
|
35
|
-
from model_compression_toolkit.keras.tensor_marking import get_node_stats_collector
|
|
36
|
-
import model_compression_toolkit.keras.constants as keras_constants
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class KerasImplementation(FrameworkImplementation):
|
|
40
|
-
"""
|
|
41
|
-
A class with implemented methods to support optimizing Keras models.
|
|
42
|
-
"""
|
|
43
|
-
|
|
44
|
-
def __init__(self):
|
|
45
|
-
super().__init__()
|
|
46
|
-
|
|
47
|
-
@property
|
|
48
|
-
def constants(self):
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
Returns: Module of Keras constants.
|
|
52
|
-
|
|
53
|
-
"""
|
|
54
|
-
return keras_constants
|
|
55
|
-
|
|
56
|
-
def model_reader(self, model: Model) -> Graph:
|
|
57
|
-
"""
|
|
58
|
-
Convert a framework's model into a graph.
|
|
59
|
-
Args:
|
|
60
|
-
model: Framework's model.
|
|
61
|
-
|
|
62
|
-
Returns:
|
|
63
|
-
Graph representing the input model.
|
|
64
|
-
"""
|
|
65
|
-
return model_reader(model)
|
|
66
|
-
|
|
67
|
-
def to_numpy(self, tensor: tf.Tensor) -> np.ndarray:
|
|
68
|
-
"""
|
|
69
|
-
Convert framework's tensor to a Numpy array.
|
|
70
|
-
Args:
|
|
71
|
-
tensor: Framework's tensor.
|
|
72
|
-
|
|
73
|
-
Returns:
|
|
74
|
-
Numpy array converted from the input tensor.
|
|
75
|
-
"""
|
|
76
|
-
return tensor.numpy()
|
|
77
|
-
|
|
78
|
-
def model_builder(self,
|
|
79
|
-
graph: Graph,
|
|
80
|
-
mode: ModelBuilderMode,
|
|
81
|
-
append2output: List[Any] = None,
|
|
82
|
-
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO) -> Tuple[Model, UserInformation]:
|
|
83
|
-
"""
|
|
84
|
-
Build a Keras model from a graph.
|
|
85
|
-
The mode determines how the model should be built. append2output is a list of Nodes
|
|
86
|
-
to set as the model outputs.
|
|
87
|
-
|
|
88
|
-
Args:
|
|
89
|
-
graph: Graph to build the model from.
|
|
90
|
-
mode: Mode for how to build the model.
|
|
91
|
-
append2output: List of Nodes to set as the model's outputs.
|
|
92
|
-
fw_info: FrameworkInfo object with information about the specific framework's model
|
|
93
|
-
|
|
94
|
-
Returns:
|
|
95
|
-
A tuple of the Keras model that was built and a UserInformation object.
|
|
96
|
-
"""
|
|
97
|
-
return model_builder(graph,
|
|
98
|
-
mode,
|
|
99
|
-
append2output,
|
|
100
|
-
fw_info)
|
|
101
|
-
|
|
102
|
-
def shift_negative_correction(self,
|
|
103
|
-
graph: Graph,
|
|
104
|
-
qc: QuantizationConfig,
|
|
105
|
-
fw_info: FrameworkInfo) -> Graph:
|
|
106
|
-
"""
|
|
107
|
-
Apply shift negative correction (SNC) on a graph.
|
|
108
|
-
|
|
109
|
-
Args:
|
|
110
|
-
graph: Graph to apply SNC on.
|
|
111
|
-
qc: Quantization configuration.
|
|
112
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
|
113
|
-
|
|
114
|
-
Returns:
|
|
115
|
-
Graph after SNC.
|
|
116
|
-
"""
|
|
117
|
-
return apply_shift_negative_correction(graph,
|
|
118
|
-
qc,
|
|
119
|
-
fw_info)
|
|
120
|
-
|
|
121
|
-
def attach_sc_to_node(self, node: Node,
|
|
122
|
-
fw_info: FrameworkInfo) -> common.statistics_collector.BaseStatsContainer:
|
|
123
|
-
"""
|
|
124
|
-
Return a statistics collector that should be attached to a node's output
|
|
125
|
-
during statistics collection.
|
|
126
|
-
|
|
127
|
-
Args:
|
|
128
|
-
node: Node to return its collector.
|
|
129
|
-
fw_info: FrameworkInfo object with information about the specific framework's model
|
|
130
|
-
|
|
131
|
-
Returns:
|
|
132
|
-
Statistics collector for the node.
|
|
133
|
-
"""
|
|
134
|
-
return get_node_stats_collector(node,
|
|
135
|
-
fw_info)
|
|
136
|
-
|
|
137
|
-
def get_substitutions_marking(self) -> List[common.BaseSubstitution]:
|
|
138
|
-
"""
|
|
139
|
-
|
|
140
|
-
Returns: A list of the framework substitutions used for marking
|
|
141
|
-
points we fuse.
|
|
142
|
-
|
|
143
|
-
"""
|
|
144
|
-
return [MarkActivation()]
|
|
145
|
-
|
|
146
|
-
def get_substitutions_pre_statistics_collection(self) -> List[common.BaseSubstitution]:
|
|
147
|
-
"""
|
|
148
|
-
|
|
149
|
-
Returns: A list of the framework substitutions used before we collect statistics.
|
|
150
|
-
|
|
151
|
-
"""
|
|
152
|
-
return [SeparableConvDecomposition(),
|
|
153
|
-
ActivationDecomposition(),
|
|
154
|
-
BatchNormalizationFolding()]
|
|
155
|
-
|
|
156
|
-
def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[
|
|
157
|
-
common.BaseSubstitution]:
|
|
158
|
-
"""
|
|
159
|
-
Return a list of the framework substitutions used after we collect statistics.
|
|
160
|
-
|
|
161
|
-
Args:
|
|
162
|
-
quant_config: QuantizationConfig to determine which substitutions to return.
|
|
163
|
-
|
|
164
|
-
Returns:
|
|
165
|
-
A list of the framework substitutions used after we collect statistics.
|
|
166
|
-
"""
|
|
167
|
-
substitutions_list = []
|
|
168
|
-
if quant_config.input_scaling:
|
|
169
|
-
substitutions_list.append(InputScaling())
|
|
170
|
-
substitutions_list.append(InputScalingWithPad())
|
|
171
|
-
|
|
172
|
-
if quant_config.relu_unbound_correction:
|
|
173
|
-
substitutions_list.append(ReLUBoundCorrection())
|
|
174
|
-
return substitutions_list
|
|
175
|
-
|
|
176
|
-
def get_substitutions_channel_equalization(self,
|
|
177
|
-
quant_config: QuantizationConfig,
|
|
178
|
-
fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
|
|
179
|
-
"""
|
|
180
|
-
Return a list of the framework substitutions used for channel equalization.
|
|
181
|
-
|
|
182
|
-
Args:
|
|
183
|
-
quant_config: QuantizationConfig to determine which substitutions to return.
|
|
184
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
|
185
|
-
|
|
186
|
-
Returns:
|
|
187
|
-
A list of the framework substitutions used for channel equalization.
|
|
188
|
-
"""
|
|
189
|
-
substitutions_list = []
|
|
190
|
-
if quant_config.activation_channel_equalization:
|
|
191
|
-
substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
|
|
192
|
-
ScaleEqualizationWithPad(quant_config, fw_info),
|
|
193
|
-
ScaleEqualizationMidActivation(quant_config, fw_info),
|
|
194
|
-
ScaleEqualizationMidActivationWithPad(quant_config, fw_info)])
|
|
195
|
-
return substitutions_list
|
|
196
|
-
|
|
197
|
-
def get_substitutions_pre_build(self) -> List[common.BaseSubstitution]:
|
|
198
|
-
"""
|
|
199
|
-
|
|
200
|
-
Returns: A list of the framework substitutions used before we build a quantized model.
|
|
201
|
-
|
|
202
|
-
"""
|
|
203
|
-
|
|
204
|
-
return [RemoveReLUUpperBound()]
|
|
205
|
-
|
|
206
|
-
def gptq_training(self,
|
|
207
|
-
graph: Graph,
|
|
208
|
-
representative_data_gen: Callable,
|
|
209
|
-
gptq_config: GradientPTQConfig,
|
|
210
|
-
fw_info: FrameworkInfo) -> Graph:
|
|
211
|
-
"""
|
|
212
|
-
Update a graph using GPTQ after minimizing the loss between the float model's output
|
|
213
|
-
and the quantized model's outputs.
|
|
214
|
-
|
|
215
|
-
Args:
|
|
216
|
-
graph: Graph to fine-tune.
|
|
217
|
-
representative_data_gen: Dataset to use for inputs of the models.
|
|
218
|
-
gptq_config: GradientPTQConfig with configuration for the fine-tuning process.
|
|
219
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
|
220
|
-
|
|
221
|
-
Returns:
|
|
222
|
-
Updated graph after GPTQ.
|
|
223
|
-
"""
|
|
224
|
-
|
|
225
|
-
return gptq_training_wrapper(graph,
|
|
226
|
-
representative_data_gen,
|
|
227
|
-
gptq_config,
|
|
228
|
-
fw_info)
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
def get_sensitivity_evaluation_fn(self,
|
|
232
|
-
graph: Graph,
|
|
233
|
-
quant_config: MixedPrecisionQuantizationConfig,
|
|
234
|
-
metrics_weights: np.ndarray,
|
|
235
|
-
representative_data_gen: Callable,
|
|
236
|
-
fw_info: FrameworkInfo) -> Callable:
|
|
237
|
-
"""
|
|
238
|
-
Create and return a function to compute a sensitivity metric for a mixed-precision
|
|
239
|
-
configuration (comparing to the float Keras model).
|
|
240
|
-
|
|
241
|
-
Args:
|
|
242
|
-
graph: Graph to build its float and mixed-precision Keras models.
|
|
243
|
-
quant_config: QuantizationConfig of how the model should be quantized.
|
|
244
|
-
metrics_weights: Array of weights to weight the sensitivity among different layers.
|
|
245
|
-
representative_data_gen: Dataset to use for retrieving images for the models inputs.
|
|
246
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
|
247
|
-
|
|
248
|
-
Returns:
|
|
249
|
-
A function that computes the metric.
|
|
250
|
-
"""
|
|
251
|
-
|
|
252
|
-
return get_sensitivity_evaluation(graph,
|
|
253
|
-
quant_config,
|
|
254
|
-
metrics_weights,
|
|
255
|
-
representative_data_gen,
|
|
256
|
-
fw_info)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|