mct-nightly 1.1.0.7012022.post2611__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/common/__init__.py +2 -2
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
- model_compression_toolkit/common/collectors/mean_collector.py +2 -3
- model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
- model_compression_toolkit/common/constants.py +0 -1
- model_compression_toolkit/common/framework_implementation.py +6 -22
- model_compression_toolkit/common/framework_info.py +7 -39
- model_compression_toolkit/common/graph/__init__.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +34 -34
- model_compression_toolkit/common/graph/edge.py +3 -3
- model_compression_toolkit/common/graph/graph_matchers.py +3 -3
- model_compression_toolkit/common/graph/graph_searches.py +4 -4
- model_compression_toolkit/common/graph/graph_vis.py +116 -0
- model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
- model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/common/model_collector.py +12 -14
- model_compression_toolkit/common/network_editors/actions.py +23 -19
- model_compression_toolkit/common/post_training_quantization.py +7 -20
- model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
- model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
- model_compression_toolkit/common/quantization/quantization_config.py +6 -6
- model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
- model_compression_toolkit/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
- model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
- model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
- model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
- model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
- model_compression_toolkit/keras/constants.py +0 -3
- model_compression_toolkit/keras/default_framework_info.py +7 -33
- model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
- model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
- model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
- model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
- model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
- model_compression_toolkit/keras/keras_implementation.py +8 -28
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/keras/quantization_facade.py +1 -5
- model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
- model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
- model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
- model_compression_toolkit/keras/reader/common.py +11 -9
- model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
- model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
- model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
- model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
- model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
- model_compression_toolkit/keras/reader/node_builder.py +15 -65
- model_compression_toolkit/keras/reader/reader.py +5 -5
- model_compression_toolkit/keras/tensor_marking.py +113 -0
- model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
- model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
- model_compression_toolkit/common/graph/functional_node.py +0 -59
- model_compression_toolkit/common/model_validation.py +0 -43
- model_compression_toolkit/common/node_prior_info.py +0 -29
- model_compression_toolkit/keras/keras_model_validation.py +0 -38
- model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
|
@@ -16,8 +16,7 @@
|
|
|
16
16
|
from abc import ABC, abstractmethod
|
|
17
17
|
from collections import namedtuple
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.common.graph.
|
|
20
|
-
from model_compression_toolkit.common.quantization import quantization_params_generation
|
|
19
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
21
20
|
from model_compression_toolkit.common.quantization.quantization_params_fn_selection import \
|
|
22
21
|
get_activation_quantization_params_fn, get_weights_quantization_params_fn
|
|
23
22
|
|
|
@@ -51,7 +50,7 @@ class BaseAction(ABC):
|
|
|
51
50
|
"""
|
|
52
51
|
|
|
53
52
|
@abstractmethod
|
|
54
|
-
def apply(self, node:
|
|
53
|
+
def apply(self, node: Node, graph, fw_info):
|
|
55
54
|
"""
|
|
56
55
|
Apply an action on the node after matching the node with a node filter.
|
|
57
56
|
|
|
@@ -82,7 +81,7 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
|
|
|
82
81
|
"""
|
|
83
82
|
self.kwargs = kwargs
|
|
84
83
|
|
|
85
|
-
def apply(self, node:
|
|
84
|
+
def apply(self, node: Node, graph, fw_info):
|
|
86
85
|
"""
|
|
87
86
|
Change the attribute 'attr_name' in quant_config with 'attr_value'.
|
|
88
87
|
|
|
@@ -94,9 +93,10 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
|
|
|
94
93
|
Returns:
|
|
95
94
|
The node after its quant_config has been modified.
|
|
96
95
|
"""
|
|
97
|
-
|
|
98
|
-
for
|
|
99
|
-
|
|
96
|
+
if node.candidates_weights_quantization_cfg is not None:
|
|
97
|
+
for nqc in node.candidates_weights_quantization_cfg:
|
|
98
|
+
for attr_name, attr_value in self.kwargs.items():
|
|
99
|
+
nqc.set_quant_config_attr(attr_name, attr_value)
|
|
100
100
|
|
|
101
101
|
|
|
102
102
|
class ChangeFinalWeightsQuantConfigAttr(BaseAction):
|
|
@@ -113,7 +113,7 @@ class ChangeFinalWeightsQuantConfigAttr(BaseAction):
|
|
|
113
113
|
"""
|
|
114
114
|
self.kwargs = kwargs
|
|
115
115
|
|
|
116
|
-
def apply(self, node:
|
|
116
|
+
def apply(self, node: Node, graph, fw_info):
|
|
117
117
|
if node.final_weights_quantization_cfg is not None:
|
|
118
118
|
for attr_name, attr_value in self.kwargs.items():
|
|
119
119
|
node.final_weights_quantization_cfg.set_quant_config_attr(attr_name, attr_value)
|
|
@@ -134,7 +134,7 @@ class ChangeActivationQuantConfigAttr(BaseAction):
|
|
|
134
134
|
"""
|
|
135
135
|
self.kwargs = kwargs
|
|
136
136
|
|
|
137
|
-
def apply(self, node:
|
|
137
|
+
def apply(self, node: Node, graph, fw_info):
|
|
138
138
|
"""
|
|
139
139
|
Change the attribute 'attr_name' in quant_config with 'attr_value'.
|
|
140
140
|
|
|
@@ -146,8 +146,9 @@ class ChangeActivationQuantConfigAttr(BaseAction):
|
|
|
146
146
|
Returns:q
|
|
147
147
|
The node after its quant_config has been modified.
|
|
148
148
|
"""
|
|
149
|
-
|
|
150
|
-
|
|
149
|
+
if node.activation_quantization_cfg is not None:
|
|
150
|
+
for attr_name, attr_value in self.kwargs.items():
|
|
151
|
+
node.activation_quantization_cfg.set_quant_config_attr(attr_name, attr_value)
|
|
151
152
|
|
|
152
153
|
|
|
153
154
|
class ChangeQuantizationParamFunction(BaseAction):
|
|
@@ -166,7 +167,7 @@ class ChangeQuantizationParamFunction(BaseAction):
|
|
|
166
167
|
self.activation_quantization_params_fn = activation_quantization_params_fn
|
|
167
168
|
self.weights_quantization_params_fn = weights_quantization_params_fn
|
|
168
169
|
|
|
169
|
-
def apply(self, node:
|
|
170
|
+
def apply(self, node: Node, graph, fw_info):
|
|
170
171
|
"""
|
|
171
172
|
Change the node's weights/activations quantization params function.
|
|
172
173
|
|
|
@@ -201,7 +202,7 @@ class ChangeActivationQuantizationMethod(BaseAction):
|
|
|
201
202
|
"""
|
|
202
203
|
self.activation_quantization_method = activation_quantization_method
|
|
203
204
|
|
|
204
|
-
def apply(self, node:
|
|
205
|
+
def apply(self, node: Node, graph, fw_info):
|
|
205
206
|
"""
|
|
206
207
|
Change the node's activations quantization function.
|
|
207
208
|
|
|
@@ -216,12 +217,15 @@ class ChangeActivationQuantizationMethod(BaseAction):
|
|
|
216
217
|
"""
|
|
217
218
|
if self.activation_quantization_method is not None:
|
|
218
219
|
|
|
220
|
+
out_stats_container = graph.get_out_stats_collector(node)[0] if isinstance(
|
|
221
|
+
graph.get_out_stats_collector(node),
|
|
222
|
+
list) else graph.get_out_stats_collector(
|
|
223
|
+
node)
|
|
224
|
+
|
|
219
225
|
activation_quantization_params_fn = get_activation_quantization_params_fn(
|
|
220
226
|
self.activation_quantization_method,
|
|
221
|
-
node.activation_quantization_cfg.activation_threshold_method
|
|
222
|
-
|
|
223
|
-
if node.prior_info.is_output_bounded():
|
|
224
|
-
activation_quantization_params_fn = quantization_params_generation.no_clipping_selection_min_max
|
|
227
|
+
node.activation_quantization_cfg.activation_threshold_method,
|
|
228
|
+
out_stats_container.use_min_max)
|
|
225
229
|
|
|
226
230
|
node.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
|
|
227
231
|
activation_quantization_fn = fw_info.activation_quantizer_mapping.get(self.activation_quantization_method)
|
|
@@ -248,7 +252,7 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
|
|
|
248
252
|
|
|
249
253
|
self.weights_quantization_method = weights_quantization_method
|
|
250
254
|
|
|
251
|
-
def apply(self, node:
|
|
255
|
+
def apply(self, node: Node, graph, fw_info):
|
|
252
256
|
"""
|
|
253
257
|
Change the node's weights quantization function.
|
|
254
258
|
|
|
@@ -292,7 +296,7 @@ class ChangeCandidtaesWeightsQuantizationMethod(BaseAction):
|
|
|
292
296
|
"""
|
|
293
297
|
self.weights_quantization_method = weights_quantization_method
|
|
294
298
|
|
|
295
|
-
def apply(self, node:
|
|
299
|
+
def apply(self, node: Node, graph, fw_info):
|
|
296
300
|
"""
|
|
297
301
|
Change the node's weights quantization function.
|
|
298
302
|
|
|
@@ -35,13 +35,11 @@ from model_compression_toolkit.common.network_editors.actions import EditRule
|
|
|
35
35
|
from model_compression_toolkit.common.network_editors.edit_network import edit_network_graph
|
|
36
36
|
from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
|
|
37
37
|
MixedPrecisionQuantizationConfig
|
|
38
|
-
from model_compression_toolkit.common.quantization.quantization_params_fn_selection import \
|
|
39
|
-
get_activation_quantization_params_fn
|
|
40
38
|
from model_compression_toolkit.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
41
39
|
from model_compression_toolkit.common.bias_correction.compute_bias_correction_of_graph import compute_bias_correction_of_graph
|
|
42
40
|
|
|
43
41
|
from model_compression_toolkit.common.quantization.quantization_analyzer import analyzer_graph
|
|
44
|
-
from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
|
|
42
|
+
from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
|
|
45
43
|
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
46
44
|
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_computation import \
|
|
47
45
|
calculate_quantization_params
|
|
@@ -364,21 +362,6 @@ def _prepare_model_for_quantization(in_model: Any,
|
|
|
364
362
|
if tb_w is not None:
|
|
365
363
|
tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')
|
|
366
364
|
|
|
367
|
-
#########################################
|
|
368
|
-
# Set prior info to nodes
|
|
369
|
-
##########################################
|
|
370
|
-
for node in transformed_graph.nodes:
|
|
371
|
-
node.prior_info = fw_impl.get_node_prior_info(node=node,
|
|
372
|
-
fw_info=fw_info)
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
######################################
|
|
376
|
-
# Add quantization configurations
|
|
377
|
-
######################################
|
|
378
|
-
transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
|
|
379
|
-
quant_config=quant_config,
|
|
380
|
-
fw_info=fw_info)
|
|
381
|
-
|
|
382
365
|
######################################
|
|
383
366
|
# Graph marking points
|
|
384
367
|
######################################
|
|
@@ -398,7 +381,6 @@ def _prepare_model_for_quantization(in_model: Any,
|
|
|
398
381
|
if tb_w is not None:
|
|
399
382
|
tb_w.add_graph(transformed_graph, 'after_analyzer_graph')
|
|
400
383
|
|
|
401
|
-
|
|
402
384
|
######################################
|
|
403
385
|
# Statistic collection
|
|
404
386
|
######################################
|
|
@@ -409,6 +391,12 @@ def _prepare_model_for_quantization(in_model: Any,
|
|
|
409
391
|
for _ in tqdm(range(n_iter)):
|
|
410
392
|
mi.infer(representative_data_gen())
|
|
411
393
|
|
|
394
|
+
######################################
|
|
395
|
+
# Add quantization configurations
|
|
396
|
+
######################################
|
|
397
|
+
transformed_graph = set_quantization_configuration_to_graph(transformed_graph,
|
|
398
|
+
quant_config,
|
|
399
|
+
fw_info)
|
|
412
400
|
|
|
413
401
|
######################################
|
|
414
402
|
# Edit network according to user specific settings
|
|
@@ -469,4 +457,3 @@ def _prepare_model_for_quantization(in_model: Any,
|
|
|
469
457
|
assert n.final_weights_quantization_cfg is None
|
|
470
458
|
|
|
471
459
|
return tg_with_bias
|
|
472
|
-
|
|
@@ -62,7 +62,8 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
|
|
|
62
62
|
def __init__(self,
|
|
63
63
|
qc: QuantizationConfig,
|
|
64
64
|
activation_quantization_fn: Callable,
|
|
65
|
-
activation_quantization_params_fn: Callable
|
|
65
|
+
activation_quantization_params_fn: Callable,
|
|
66
|
+
activation_is_signed: bool = None
|
|
66
67
|
):
|
|
67
68
|
"""
|
|
68
69
|
|
|
@@ -70,10 +71,11 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
|
|
|
70
71
|
qc: QuantizationConfig to create the node's config from.
|
|
71
72
|
activation_quantization_fn: Function to use when quantizing the node's activations.
|
|
72
73
|
activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
|
|
74
|
+
activation_is_signed: Signedness of the activation quantized range.
|
|
73
75
|
"""
|
|
74
|
-
|
|
75
76
|
self.activation_quantization_fn = activation_quantization_fn
|
|
76
77
|
self.activation_quantization_params_fn = activation_quantization_params_fn
|
|
78
|
+
self.activation_is_signed = activation_is_signed
|
|
77
79
|
self.activation_quantization_params = {}
|
|
78
80
|
self.activation_threshold_method = qc.activation_threshold_method
|
|
79
81
|
self.activation_quantization_method = qc.activation_quantization_method
|
|
@@ -89,14 +91,6 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
|
|
|
89
91
|
self.shift_negative_ratio = qc.shift_negative_ratio
|
|
90
92
|
self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
|
|
91
93
|
|
|
92
|
-
def generate_quantization_node(self) -> Callable:
|
|
93
|
-
"""
|
|
94
|
-
Returns: Quantization function to use for quantizing the node's activations,
|
|
95
|
-
with the node's quantization configuration properties.
|
|
96
|
-
"""
|
|
97
|
-
return self.activation_quantization_fn(self.activation_n_bits,
|
|
98
|
-
self.activation_quantization_params)
|
|
99
|
-
|
|
100
94
|
def set_activation_quantization_fn(self, activation_quantization_fn: Callable):
|
|
101
95
|
"""
|
|
102
96
|
Sets activation quantization function for the node.
|
|
@@ -126,7 +120,6 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
|
|
|
126
120
|
activation_params: Dictionary that contains weight quantization params.
|
|
127
121
|
|
|
128
122
|
"""
|
|
129
|
-
assert self.enable_activation_quantization
|
|
130
123
|
for param_name, param_value in activation_params.items():
|
|
131
124
|
self.activation_quantization_params[param_name] = param_value
|
|
132
125
|
|
|
@@ -205,7 +198,6 @@ class NodeWeightsQuantizationConfig(BaseNodeNodeQuantizationConfig):
|
|
|
205
198
|
weights_params: Dictionary that contains weight quantization params.
|
|
206
199
|
|
|
207
200
|
"""
|
|
208
|
-
assert self.enable_weights_quantization
|
|
209
201
|
for param_name, param_value in weights_params.items():
|
|
210
202
|
self.weights_quantization_params[param_name] = param_value
|
|
211
203
|
|
|
@@ -218,7 +210,7 @@ class NodeWeightsQuantizationConfig(BaseNodeNodeQuantizationConfig):
|
|
|
218
210
|
Recalculated weights quantization params from the kernel and channel axis.
|
|
219
211
|
|
|
220
212
|
"""
|
|
221
|
-
|
|
213
|
+
|
|
222
214
|
if self.weights_quantization_params_fn is not None:
|
|
223
215
|
self.set_weights_quantization_param(self.weights_quantization_params_fn(tensor_data,
|
|
224
216
|
p=self.l_p_value,
|
|
@@ -21,20 +21,17 @@ from model_compression_toolkit import common
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def create_tensor2node(graph: common.Graph,
|
|
24
|
-
node: common.
|
|
25
|
-
fw_info: common.FrameworkInfo):
|
|
24
|
+
node: common.Node):
|
|
26
25
|
"""
|
|
27
26
|
Force tensor creation and assignment for a node.
|
|
28
27
|
Args:
|
|
29
28
|
graph: Graph of the node (for retrieving the current tensor).
|
|
30
29
|
node: Node to create a tensor for.
|
|
31
|
-
fw_info: Specific framework information (for example, output channels index).
|
|
32
30
|
|
|
33
31
|
"""
|
|
34
32
|
current_tensor = graph.get_out_stats_collector(node)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
graph.set_out_stats_collector_to_node(node, common.StatsCollector(output_channel_index=fw_info.output_channel_index))
|
|
33
|
+
if isinstance(current_tensor, common.NoStatsContainer) or current_tensor is None:
|
|
34
|
+
graph.set_out_stats_collector_to_node(node, common.StatsContainer())
|
|
38
35
|
|
|
39
36
|
|
|
40
37
|
def analyzer_graph(node_analyze_func: Callable,
|
|
@@ -56,7 +53,7 @@ def analyzer_graph(node_analyze_func: Callable,
|
|
|
56
53
|
"""
|
|
57
54
|
nodes_sorted = topological_sort(graph)
|
|
58
55
|
for n in nodes_sorted:
|
|
59
|
-
|
|
56
|
+
t = node_analyze_func(n, fw_info) # Get tensor for the node
|
|
60
57
|
# If we use bias correction, and the node has coefficients to quantize, we need to make sure
|
|
61
58
|
# its previous nodes' tensors are consistent with this node.
|
|
62
59
|
# TODO: factor tensor marking in case of bias correction.
|
|
@@ -64,7 +61,6 @@ def analyzer_graph(node_analyze_func: Callable,
|
|
|
64
61
|
for ie in graph.incoming_edges(n):
|
|
65
62
|
input_node = ie.source_node
|
|
66
63
|
create_tensor2node(graph,
|
|
67
|
-
input_node
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
graph.set_out_stats_collector_to_node(n, sc)
|
|
64
|
+
input_node)
|
|
65
|
+
if t is not None:
|
|
66
|
+
graph.set_out_stats_collector_to_node(n, t)
|
|
@@ -155,12 +155,12 @@ DEFAULTCONFIG = QuantizationConfig(ThresholdSelectionMethod.MSE,
|
|
|
155
155
|
ThresholdSelectionMethod.MSE,
|
|
156
156
|
QuantizationMethod.POWER_OF_TWO,
|
|
157
157
|
QuantizationMethod.POWER_OF_TWO,
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
158
|
+
8,
|
|
159
|
+
8,
|
|
160
|
+
False,
|
|
161
|
+
True,
|
|
162
|
+
True,
|
|
163
|
+
False)
|
|
164
164
|
|
|
165
165
|
|
|
166
166
|
|
|
@@ -23,7 +23,8 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod,
|
|
26
|
-
activation_threshold_method: ThresholdSelectionMethod
|
|
26
|
+
activation_threshold_method: ThresholdSelectionMethod,
|
|
27
|
+
use_min_max: bool) -> Callable:
|
|
27
28
|
"""
|
|
28
29
|
Generate a function for finding activation quantization threshold.
|
|
29
30
|
|
|
@@ -37,7 +38,7 @@ def get_activation_quantization_params_fn(activation_quantization_method: Quanti
|
|
|
37
38
|
"""
|
|
38
39
|
if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
|
|
39
40
|
# Use min/max as the threshold if we use NOCLIPPING
|
|
40
|
-
if activation_threshold_method == ThresholdSelectionMethod.NOCLIPPING:
|
|
41
|
+
if use_min_max or activation_threshold_method == ThresholdSelectionMethod.NOCLIPPING:
|
|
41
42
|
params_fn = quantization_params_generation.no_clipping_selection_min_max
|
|
42
43
|
# Use MSE to search_methods for the optimal threshold.
|
|
43
44
|
elif activation_threshold_method == ThresholdSelectionMethod.MSE:
|
|
@@ -15,13 +15,12 @@
|
|
|
15
15
|
import numpy as np
|
|
16
16
|
from typing import Tuple, Dict
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.common import
|
|
19
|
-
from model_compression_toolkit.common.constants import SIGNED
|
|
18
|
+
from model_compression_toolkit.common import Node, Graph
|
|
20
19
|
from model_compression_toolkit.common.quantization import quantization_params_generation
|
|
21
20
|
|
|
22
21
|
|
|
23
|
-
def get_activations_qparams(n:
|
|
24
|
-
graph: Graph) -> Dict[str, float]:
|
|
22
|
+
def get_activations_qparams(n: Node,
|
|
23
|
+
graph: Graph) -> Tuple[Dict[str, float], bool]:
|
|
25
24
|
"""
|
|
26
25
|
Compute the activations params for a given node in a graph according to a params function.
|
|
27
26
|
|
|
@@ -30,29 +29,25 @@ def get_activations_qparams(n: BaseNode,
|
|
|
30
29
|
graph: Graph the node is in.
|
|
31
30
|
|
|
32
31
|
Returns:
|
|
33
|
-
|
|
32
|
+
Tuple of the computed quantization params and sign for the node's activations quantization.
|
|
34
33
|
"""
|
|
35
|
-
|
|
36
34
|
out_stats_container = graph.get_out_stats_collector(n)
|
|
37
35
|
bins_values, bins_counts = None, None
|
|
38
36
|
|
|
39
37
|
# If the statistics container collected the histogram, we start by filtering outliers using z threshold
|
|
40
38
|
# filtering, and then computing the threshold based on the filtered histogram.
|
|
41
|
-
if out_stats_container.
|
|
39
|
+
if out_stats_container.collect_histogram:
|
|
42
40
|
bins_values, bins_counts = out_stats_container.hc.get_histogram()
|
|
43
41
|
bins_counts = quantization_params_generation.z_score_filter(n.activation_quantization_cfg.z_threshold,
|
|
44
42
|
bins_values,
|
|
45
43
|
bins_counts)
|
|
46
44
|
min_value, max_value = out_stats_container.get_min_max_values()
|
|
47
45
|
|
|
48
|
-
if
|
|
46
|
+
if out_stats_container.use_min_max:
|
|
49
47
|
signed = min_value < 0
|
|
50
48
|
else:
|
|
51
49
|
signed = np.any(bins_values < 0)
|
|
52
50
|
|
|
53
|
-
if n.prior_info.is_output_bounded():
|
|
54
|
-
n.activation_quantization_cfg.activation_quantization_params_fn = quantization_params_generation.no_clipping_selection_min_max
|
|
55
|
-
|
|
56
51
|
activation_params = n.activation_quantization_cfg.activation_quantization_params_fn(bins_values,
|
|
57
52
|
bins_counts,
|
|
58
53
|
n.activation_quantization_cfg.l_p_value,
|
|
@@ -60,6 +55,5 @@ def get_activations_qparams(n: BaseNode,
|
|
|
60
55
|
min_value,
|
|
61
56
|
max_value,
|
|
62
57
|
min_threshold=n.activation_quantization_cfg.min_threshold)
|
|
63
|
-
activation_params.update({SIGNED: signed})
|
|
64
58
|
|
|
65
|
-
return activation_params
|
|
59
|
+
return activation_params, signed
|
model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py
CHANGED
|
@@ -16,7 +16,7 @@ from typing import List
|
|
|
16
16
|
|
|
17
17
|
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
18
18
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
19
|
-
from model_compression_toolkit.common import Graph,
|
|
19
|
+
from model_compression_toolkit.common import Graph, Node, Logger
|
|
20
20
|
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_activations_computation \
|
|
21
21
|
import \
|
|
22
22
|
get_activations_qparams
|
|
@@ -26,7 +26,7 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
|
|
|
26
26
|
|
|
27
27
|
def calculate_quantization_params(graph: Graph,
|
|
28
28
|
fw_info: FrameworkInfo,
|
|
29
|
-
nodes: List[
|
|
29
|
+
nodes: List[Node] = [],
|
|
30
30
|
specific_nodes: bool = False,
|
|
31
31
|
fw_impl: FrameworkImplementation = None):
|
|
32
32
|
"""
|
|
@@ -48,7 +48,7 @@ def calculate_quantization_params(graph: Graph,
|
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
50
|
# Create a list of nodes to compute their thresholds
|
|
51
|
-
nodes_list: List[
|
|
51
|
+
nodes_list: List[Node] = nodes if specific_nodes else graph.nodes()
|
|
52
52
|
|
|
53
53
|
for n in nodes_list: # iterate only nodes that we should compute their thresholds
|
|
54
54
|
|
|
@@ -56,23 +56,25 @@ def calculate_quantization_params(graph: Graph,
|
|
|
56
56
|
input_channels_axis, activation_threshold_float = {}, {}, None, None, None, None
|
|
57
57
|
|
|
58
58
|
if fw_info.in_kernel_ops(n): # If the node has a kernel to quantize
|
|
59
|
-
if n.is_weights_quantization_enabled():
|
|
60
|
-
for candidtae_qc in n.candidates_weights_quantization_cfg:
|
|
61
|
-
output_channels_axis, _ = get_channels_axis(candidtae_qc, fw_info, n.layer_class)
|
|
62
|
-
weights_params = get_weights_qparams(n.get_weights_by_keys(fw_impl.constants.KERNEL),
|
|
63
|
-
candidtae_qc,
|
|
64
|
-
output_channels_axis)
|
|
65
59
|
|
|
66
|
-
|
|
67
|
-
|
|
60
|
+
for candidtae_qc in n.candidates_weights_quantization_cfg:
|
|
61
|
+
output_channels_axis, _ = get_channels_axis(candidtae_qc, fw_info, n.layer_class)
|
|
62
|
+
weights_params = get_weights_qparams(n.get_weights_by_keys(fw_impl.constants.KERNEL),
|
|
63
|
+
candidtae_qc,
|
|
64
|
+
output_channels_axis)
|
|
68
65
|
|
|
69
|
-
|
|
66
|
+
candidtae_qc.set_weights_quantization_param(weights_params)
|
|
67
|
+
candidtae_qc.weights_channels_axis = output_channels_axis
|
|
68
|
+
|
|
69
|
+
if n.output_quantization: # If node's activations should be quantized as well, we compute its
|
|
70
70
|
# activation threshold
|
|
71
|
-
activation_params = get_activations_qparams(n=n,
|
|
71
|
+
activation_params, activation_is_signed = get_activations_qparams(n=n,
|
|
72
|
+
graph=graph)
|
|
72
73
|
|
|
73
74
|
elif fw_info.in_activation_ops(n): # If node has no kernel, but its activations should be quantized
|
|
74
|
-
if n.
|
|
75
|
-
activation_params = get_activations_qparams(n=n,
|
|
75
|
+
if n.output_quantization:
|
|
76
|
+
activation_params, activation_is_signed = get_activations_qparams(n=n,
|
|
77
|
+
graph=graph)
|
|
76
78
|
# If node should not be quantized at all
|
|
77
79
|
elif fw_info.in_no_quantization_ops(n):
|
|
78
80
|
pass # pragma: no cover
|
|
@@ -82,5 +84,6 @@ def calculate_quantization_params(graph: Graph,
|
|
|
82
84
|
Logger.warning(f"Warning: unknown layer: {n.layer_class.__name__}")
|
|
83
85
|
|
|
84
86
|
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
|
|
85
|
-
if n.
|
|
86
|
-
n.activation_quantization_cfg.set_activation_quantization_param(activation_params)
|
|
87
|
+
if n.activation_quantization_cfg is not None:
|
|
88
|
+
n.activation_quantization_cfg.set_activation_quantization_param(activation_params)
|
|
89
|
+
n.activation_quantization_cfg.activation_is_signed = activation_is_signed
|
|
@@ -19,7 +19,7 @@ import copy
|
|
|
19
19
|
from model_compression_toolkit import common
|
|
20
20
|
from model_compression_toolkit.common import Logger
|
|
21
21
|
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
22
|
-
from model_compression_toolkit.common.graph.
|
|
22
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
23
23
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
24
24
|
from model_compression_toolkit.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
|
|
25
25
|
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_weights_computation import \
|
|
@@ -27,7 +27,7 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
|
|
30
|
-
n:
|
|
30
|
+
n:Node,
|
|
31
31
|
weights_qc: NodeWeightsQuantizationConfig,
|
|
32
32
|
fw_impl: FrameworkImplementation):
|
|
33
33
|
"""
|
|
@@ -17,7 +17,8 @@
|
|
|
17
17
|
import copy
|
|
18
18
|
from typing import List
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit.common import
|
|
20
|
+
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
21
|
+
from model_compression_toolkit.common import Logger
|
|
21
22
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
22
23
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
23
24
|
from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
|
|
@@ -46,45 +47,38 @@ def set_quantization_configuration_to_graph(graph: Graph,
|
|
|
46
47
|
"""
|
|
47
48
|
|
|
48
49
|
graph_with_qcs = copy.deepcopy(graph)
|
|
50
|
+
|
|
49
51
|
for n in graph_with_qcs.nodes:
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
52
|
+
# Set qc only when needed
|
|
53
|
+
quantize_node_weights = False
|
|
54
|
+
quantize_node_activations = False
|
|
55
|
+
|
|
56
|
+
if fw_info.in_kernel_ops(n):
|
|
57
|
+
quantize_node_weights = True
|
|
58
|
+
quantize_node_activations = n.output_quantization
|
|
59
|
+
elif fw_info.in_activation_ops(n):
|
|
60
|
+
quantize_node_activations = True
|
|
61
|
+
|
|
62
|
+
if quantize_node_activations:
|
|
63
|
+
# Create activation QC for this node
|
|
64
|
+
out_sc = graph_with_qcs.get_out_stats_collector(n)
|
|
65
|
+
sc = out_sc[0] if isinstance(out_sc, list) else out_sc
|
|
66
|
+
use_min_max = sc.use_min_max
|
|
67
|
+
n.activation_quantization_cfg = create_node_activation_qc(quant_config,
|
|
68
|
+
fw_info,
|
|
69
|
+
use_min_max)
|
|
70
|
+
if quantize_node_weights:
|
|
71
|
+
# Create weights QC for this node
|
|
72
|
+
weight_channel_axis = fw_info.kernel_channels_mapping.get(n.layer_class)[0]
|
|
73
|
+
n.candidates_weights_quantization_cfg = _create_node_candidates_weights_qc(quant_config,
|
|
74
|
+
fw_info,
|
|
75
|
+
weight_channel_axis)
|
|
53
76
|
return graph_with_qcs
|
|
54
77
|
|
|
55
78
|
|
|
56
|
-
def set_quantization_configs_to_node(node: BaseNode,
|
|
57
|
-
quant_config: QuantizationConfig,
|
|
58
|
-
fw_info: FrameworkInfo):
|
|
59
|
-
"""
|
|
60
|
-
Create and set quantization configurations to a node (for both weights and activation).
|
|
61
|
-
|
|
62
|
-
Args:
|
|
63
|
-
node: Node to set its quantization configurations.
|
|
64
|
-
quant_config: Quantization configuration to generate the node's configurations from.
|
|
65
|
-
fw_info: Information needed for quantization about the specific framework.
|
|
66
|
-
|
|
67
|
-
"""
|
|
68
|
-
# Create activation QC for this node
|
|
69
|
-
node.activation_quantization_cfg = create_node_activation_qc(quant_config,
|
|
70
|
-
fw_info)
|
|
71
|
-
|
|
72
|
-
enable_activation_quantization = quant_config.enable_activation_quantization and (fw_info.in_activation_ops(node) or fw_info.in_kernel_ops(node))
|
|
73
|
-
node.activation_quantization_cfg.enable_activation_quantization = enable_activation_quantization
|
|
74
|
-
|
|
75
|
-
# Create weights QC for this node
|
|
76
|
-
weight_channel_axis = fw_info.kernel_channels_mapping.get(node.layer_class)[0]
|
|
77
|
-
node.candidates_weights_quantization_cfg = _create_node_candidates_weights_qc(quant_config,
|
|
78
|
-
fw_info,
|
|
79
|
-
weight_channel_axis)
|
|
80
|
-
|
|
81
|
-
enable_weights_quantization = quant_config.enable_weights_quantization and fw_info.in_kernel_ops(node)
|
|
82
|
-
for qc in node.candidates_weights_quantization_cfg:
|
|
83
|
-
qc.enable_weights_quantization = enable_weights_quantization
|
|
84
|
-
|
|
85
|
-
|
|
86
79
|
def create_node_activation_qc(qc: QuantizationConfig,
|
|
87
|
-
fw_info: FrameworkInfo
|
|
80
|
+
fw_info: FrameworkInfo,
|
|
81
|
+
use_min_max: bool) -> NodeActivationQuantizationConfig:
|
|
88
82
|
"""
|
|
89
83
|
Create a activations quantization configuration from a QuantizationConfig object.
|
|
90
84
|
|
|
@@ -92,6 +86,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
|
|
|
92
86
|
qc: QuantizationConfig to create the node's config from.
|
|
93
87
|
fw_info: Information about the specific framework the node was created from (e.g., whether or not its
|
|
94
88
|
weights/activations should be quantized)
|
|
89
|
+
use_min_max: Whether the collected min/max statistics should be used when the threshold is computed or not.
|
|
95
90
|
|
|
96
91
|
Returns:
|
|
97
92
|
Activation quantization configuration of a node.
|
|
@@ -102,7 +97,8 @@ def create_node_activation_qc(qc: QuantizationConfig,
|
|
|
102
97
|
Logger.critical('Unknown quantization method for activations')
|
|
103
98
|
|
|
104
99
|
activation_quantization_params_fn = get_activation_quantization_params_fn(qc.activation_quantization_method,
|
|
105
|
-
qc.activation_threshold_method
|
|
100
|
+
qc.activation_threshold_method,
|
|
101
|
+
use_min_max)
|
|
106
102
|
|
|
107
103
|
return NodeActivationQuantizationConfig(qc,
|
|
108
104
|
activation_quantization_fn,
|
|
@@ -139,9 +135,10 @@ def create_node_weights_qc(qc: QuantizationConfig,
|
|
|
139
135
|
weight_channel_axis)
|
|
140
136
|
|
|
141
137
|
|
|
138
|
+
|
|
142
139
|
def _create_node_candidates_weights_qc(qc: QuantizationConfig,
|
|
143
|
-
|
|
144
|
-
|
|
140
|
+
fw_info: FrameworkInfo,
|
|
141
|
+
weight_channel_axis: int) -> List[NodeWeightsQuantizationConfig]:
|
|
145
142
|
"""
|
|
146
143
|
Create a list of candidates of weights quantization configurations for a node.
|
|
147
144
|
|
|
@@ -164,4 +161,4 @@ def _create_node_candidates_weights_qc(qc: QuantizationConfig,
|
|
|
164
161
|
else:
|
|
165
162
|
candidats.append(create_node_weights_qc(qc, fw_info, weight_channel_axis))
|
|
166
163
|
|
|
167
|
-
return candidats
|
|
164
|
+
return candidats
|