mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.01122021-003325__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/METADATA +3 -2
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/RECORD +31 -38
- model_compression_toolkit/__init__.py +2 -6
- model_compression_toolkit/common/base_substitutions.py +1 -0
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
- model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +2 -4
- model_compression_toolkit/common/graph/graph_matchers.py +3 -1
- model_compression_toolkit/common/graph/graph_searches.py +3 -1
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/common/network_editors/node_filters.py +1 -0
- model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
- model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
- model_compression_toolkit/common/quantization/quantize_node.py +3 -5
- model_compression_toolkit/keras/__init__.py +2 -0
- model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
- model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
- model_compression_toolkit/keras/default_framework_info.py +0 -1
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
- model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
- model_compression_toolkit/keras/quantization_facade.py +524 -188
- model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
- model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
- model_compression_toolkit/common/framework_implementation.py +0 -239
- model_compression_toolkit/common/gptq/__init__.py +0 -14
- model_compression_toolkit/common/gptq/gptq_config.py +0 -65
- model_compression_toolkit/common/model_builder_mode.py +0 -34
- model_compression_toolkit/common/post_training_quantization.py +0 -459
- model_compression_toolkit/common/substitutions/__init__.py +0 -14
- model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
- model_compression_toolkit/keras/keras_implementation.py +0 -256
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/LICENSE +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/WHEEL +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/top_level.txt +0 -0
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
from enum import Enum
|
|
18
|
+
|
|
17
19
|
import tensorflow as tf
|
|
18
20
|
|
|
19
21
|
# As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
|
|
@@ -24,7 +26,6 @@ else:
|
|
|
24
26
|
from keras import Input
|
|
25
27
|
from keras.layers.core import TFOpLambda
|
|
26
28
|
|
|
27
|
-
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
28
29
|
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
|
|
29
30
|
from tensorflow.python.keras.layers import Layer
|
|
30
31
|
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
|
|
@@ -40,6 +41,7 @@ from model_compression_toolkit.keras.quantizer.gradient_ptq.config_factory impor
|
|
|
40
41
|
from model_compression_toolkit.common import Node, Graph
|
|
41
42
|
from model_compression_toolkit.common.graph.edge import EDGE_SINK_INDEX
|
|
42
43
|
from model_compression_toolkit.keras.back2framework.instance_builder import OperationHandler
|
|
44
|
+
from model_compression_toolkit.keras.graph_substitutions.substituter import pre_build_substitute
|
|
43
45
|
from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
|
|
44
46
|
|
|
45
47
|
# In tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda.
|
|
@@ -48,6 +50,22 @@ FQ_NODE_OP_V2_4 = 'quantization.fake_quant_with_min_max_vars'
|
|
|
48
50
|
BATCH_INPUT_SHAPE = 'batch_input_shape'
|
|
49
51
|
|
|
50
52
|
|
|
53
|
+
class ModelBuilderMode(Enum):
|
|
54
|
+
"""
|
|
55
|
+
Mode for building the model back from a graph:
|
|
56
|
+
FLOAT - Build model for statistics collection. Model's outputs list contains all output tensors of all nodes
|
|
57
|
+
in the graph.
|
|
58
|
+
QUANTIZED - Build a quantized model using the nodes' quantization attributes for adding
|
|
59
|
+
quantization nodes to the model.
|
|
60
|
+
GPTQ - Build a quantized model using the nodes' quantization attributes for wrapping
|
|
61
|
+
layers with QuantizeWrapper and output comparing points.
|
|
62
|
+
"""
|
|
63
|
+
FLOAT = 0
|
|
64
|
+
QUANTIZED = 1
|
|
65
|
+
GPTQ = 2
|
|
66
|
+
MIXEDPRECISION = 3
|
|
67
|
+
|
|
68
|
+
|
|
51
69
|
def get_node_name_from_layer(layer: Layer) -> str:
|
|
52
70
|
"""
|
|
53
71
|
Get a node's name from the layer it was built from. For TensorFlowOpLayer
|
|
@@ -182,6 +200,11 @@ def model_builder(graph: common.Graph,
|
|
|
182
200
|
Returns:
|
|
183
201
|
A tuple of the model, and an UserInformation object.
|
|
184
202
|
"""
|
|
203
|
+
|
|
204
|
+
# For quantized models, first apply some substitutions.
|
|
205
|
+
if mode != ModelBuilderMode.FLOAT:
|
|
206
|
+
graph = pre_build_substitute(graph)
|
|
207
|
+
|
|
185
208
|
node_to_output_tensors_dict = dict()
|
|
186
209
|
model_output_tensors = []
|
|
187
210
|
|
|
@@ -14,14 +14,13 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
import numpy as np
|
|
18
17
|
from typing import List
|
|
19
18
|
|
|
20
|
-
|
|
21
|
-
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
22
21
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
23
22
|
from model_compression_toolkit.common.logger import Logger
|
|
24
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
class ModelCollector(object):
|
|
@@ -32,21 +31,16 @@ class ModelCollector(object):
|
|
|
32
31
|
for thresholds calculations.
|
|
33
32
|
"""
|
|
34
33
|
|
|
35
|
-
def __init__(self, graph: Graph
|
|
34
|
+
def __init__(self, graph: Graph):
|
|
36
35
|
"""
|
|
37
36
|
Build a Keras model from the passed graph, and set the model's
|
|
38
37
|
outputs to be all layers' outputs.
|
|
39
38
|
|
|
40
39
|
Args:
|
|
41
40
|
graph: Graph to build a model from it.
|
|
42
|
-
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
43
|
-
|
|
44
41
|
"""
|
|
45
42
|
|
|
46
43
|
self.graph = graph
|
|
47
|
-
self.fw_impl = fw_impl
|
|
48
|
-
self.fw_info = fw_info
|
|
49
|
-
|
|
50
44
|
node2fetch = [] # List of graph nodes, the model should output their outputs.
|
|
51
45
|
stats_containers_list = [] # List of output statistics containers of nodes ordered
|
|
52
46
|
# the same as node2fetch so statistics of outputs can be gathered for the correct statistics container.
|
|
@@ -72,10 +66,9 @@ class ModelCollector(object):
|
|
|
72
66
|
|
|
73
67
|
# Build a float model and output all layers' outputs
|
|
74
68
|
# (that should be collected) as the model's outputs
|
|
75
|
-
self.model, _ =
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
fw_info=self.fw_info)
|
|
69
|
+
self.model, _ = model_builder(self.graph,
|
|
70
|
+
mode=ModelBuilderMode.FLOAT,
|
|
71
|
+
append2output=node2fetch)
|
|
79
72
|
|
|
80
73
|
def infer(self, inputs_list: List[np.ndarray]):
|
|
81
74
|
"""
|
|
@@ -84,9 +77,7 @@ class ModelCollector(object):
|
|
|
84
77
|
|
|
85
78
|
Args:
|
|
86
79
|
inputs_list: Inputs for the model inferring.
|
|
87
|
-
|
|
88
80
|
"""
|
|
89
|
-
|
|
90
81
|
# TODO: Thinking about delegating collections to framework
|
|
91
82
|
# TODO: migrate datasets to framework datasets
|
|
92
83
|
tensor_data = self.model(list(inputs_list))
|
|
@@ -98,6 +89,6 @@ class ModelCollector(object):
|
|
|
98
89
|
if len(sc) != len(td):
|
|
99
90
|
Logger.exception('"tensor_data" and the model tensor_list must be of the same length')
|
|
100
91
|
for tdi, sci in zip(td, sc):
|
|
101
|
-
sci.update_statistics(
|
|
92
|
+
sci.update_statistics(tdi.numpy())
|
|
102
93
|
else:
|
|
103
|
-
sc.update_statistics(
|
|
94
|
+
sc.update_statistics(td.numpy())
|
|
@@ -27,7 +27,6 @@ from model_compression_toolkit.common.quantization.quantizers.power_of_two_quant
|
|
|
27
27
|
from model_compression_toolkit.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
|
|
28
28
|
KERNEL, DEPTHWISE_KERNEL
|
|
29
29
|
from model_compression_toolkit.keras.quantizer.fake_quant_builder import constraint_quantization
|
|
30
|
-
|
|
31
30
|
"""
|
|
32
31
|
Division of Keras layers by how they should be quantized.
|
|
33
32
|
KERNEL_OPS: Layers that their coefficients should be quantized.
|
|
@@ -19,21 +19,71 @@ import tensorflow as tf
|
|
|
19
19
|
from tqdm import tqdm
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit import common
|
|
22
|
-
from model_compression_toolkit.common.gptq.gptq_config import GradientPTQConfig
|
|
23
22
|
from model_compression_toolkit.common import Graph
|
|
24
|
-
from model_compression_toolkit.keras.back2framework.model_builder import model_builder
|
|
25
|
-
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
23
|
+
from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
|
|
26
24
|
from model_compression_toolkit.keras.gradient_ptq.graph_info import get_compare_points, \
|
|
27
25
|
get_trainable_parameters
|
|
28
26
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
29
27
|
from model_compression_toolkit.keras.gradient_ptq.graph_update import update_graph_after_gptq
|
|
28
|
+
from model_compression_toolkit.keras.gradient_ptq.gptq_loss import \
|
|
29
|
+
multiple_tensors_mse_loss
|
|
30
|
+
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
|
30
31
|
import numpy as np
|
|
31
32
|
|
|
32
33
|
|
|
34
|
+
class GradientPTQConfig:
|
|
35
|
+
"""
|
|
36
|
+
Configuration to use for quantization with GradientPTQ (experimental).
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
def __init__(self,
|
|
40
|
+
n_iter: int,
|
|
41
|
+
optimizer: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=0.0001),
|
|
42
|
+
loss: Callable = multiple_tensors_mse_loss,
|
|
43
|
+
log_function: Callable = None,
|
|
44
|
+
train_bias: bool = True,
|
|
45
|
+
representative_data_gen: Callable = None):
|
|
46
|
+
"""
|
|
47
|
+
Initialize a GradientPTQConfig.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
n_iter (int): Number of iterations to train.
|
|
51
|
+
optimizer (OptimizerV2): Optimizer to use.
|
|
52
|
+
loss (Callable): The loss to use. It should accept two lists of tf.Tensor: the first holds the quantized tensors, the second holds the float tensors.
|
|
53
|
+
log_function (Callable): Function to log information about the GPTQ process.
|
|
54
|
+
train_bias (bool): Whether to update the bias during the training or not.
|
|
55
|
+
representative_data_gen (Callable): Dataset generator.
|
|
56
|
+
|
|
57
|
+
Examples:
|
|
58
|
+
Create a GradientPTQConfig that runs for 5 iterations and uses a random dataset generator:
|
|
59
|
+
|
|
60
|
+
>>> import numpy as np
|
|
61
|
+
>>> def repr_datagen(): return [np.random.random((1,224,224,3))]
|
|
62
|
+
>>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen)
|
|
63
|
+
|
|
64
|
+
An optimizer can be passed:
|
|
65
|
+
|
|
66
|
+
>>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen, optimizer=tf.keras.optimizers.Nadam(learning_rate=0.2))
|
|
67
|
+
|
|
68
|
+
To disable the biases training, one may set train_bias to False (enabled by default):
|
|
69
|
+
|
|
70
|
+
>>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen, train_bias=False)
|
|
71
|
+
|
|
72
|
+
The configuration can then be passed to :func:`~model_compression_toolkit.keras_post_training_quantization`.
|
|
73
|
+
|
|
74
|
+
"""
|
|
75
|
+
self.n_iter = n_iter
|
|
76
|
+
self.optimizer = optimizer
|
|
77
|
+
self.loss = loss
|
|
78
|
+
self.log_function = log_function
|
|
79
|
+
self.train_bias = train_bias
|
|
80
|
+
self.representative_data_gen = representative_data_gen
|
|
81
|
+
|
|
82
|
+
|
|
33
83
|
def gptq_training_wrapper(tg: Graph,
|
|
34
84
|
representative_data_gen: Callable,
|
|
35
85
|
gptq_config: GradientPTQConfig,
|
|
36
|
-
fw_info: FrameworkInfo)
|
|
86
|
+
fw_info: FrameworkInfo):
|
|
37
87
|
"""
|
|
38
88
|
Build two models from a graph: A teacher network (float model) and a student network (quantized model).
|
|
39
89
|
Use the dataset generator to pass images through the teacher and student networks to get intermediate
|
|
@@ -57,17 +107,14 @@ def gptq_training_wrapper(tg: Graph,
|
|
|
57
107
|
#########################################
|
|
58
108
|
# Build two models and compare points
|
|
59
109
|
#########################################
|
|
60
|
-
# TODO: maybe need to add pre_build substitutions here. Ask Elad
|
|
61
110
|
compare_points, _ = get_compare_points(tg) # get compare points
|
|
62
111
|
n = len(compare_points)
|
|
63
112
|
float_model, float_user_info = model_builder(tg,
|
|
64
113
|
mode=ModelBuilderMode.FLOAT,
|
|
65
|
-
append2output=compare_points
|
|
66
|
-
fw_info=fw_info)
|
|
114
|
+
append2output=compare_points)
|
|
67
115
|
fxp_model, gptq_user_info = model_builder(tg,
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
fw_info=fw_info)
|
|
116
|
+
mode=ModelBuilderMode.GPTQ,
|
|
117
|
+
append2output=compare_points)
|
|
71
118
|
|
|
72
119
|
trainable_weights = get_trainable_parameters(fxp_model,
|
|
73
120
|
fw_info,
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import copy
|
|
18
|
+
|
|
19
|
+
from typing import List
|
|
20
|
+
|
|
21
|
+
from model_compression_toolkit import common
|
|
22
|
+
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
23
|
+
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
24
|
+
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
25
|
+
|
|
26
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.activation_decomposition import \
|
|
27
|
+
ActivationDecomposition
|
|
28
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.relu_bound_correction import \
|
|
29
|
+
ReLUBoundCorrection
|
|
30
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.batchnorm_folding import \
|
|
31
|
+
BatchNormalizationFolding
|
|
32
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.input_scaling import InputScaling, InputScalingWithPad
|
|
33
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.mark_activation import MarkActivation
|
|
34
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.remove_relu_upper_bound import \
|
|
35
|
+
RemoveReLUUpperBound
|
|
36
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.scale_equalization import \
|
|
37
|
+
ScaleEqualization, ScaleEqualizationWithPad, \
|
|
38
|
+
ScaleEqualizationMidActivation, ScaleEqualizationMidActivationWithPad
|
|
39
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.separableconv_decomposition import \
|
|
40
|
+
SeparableConvDecomposition
|
|
41
|
+
from model_compression_toolkit.keras.graph_substitutions.substitutions.shift_negative_activation import \
|
|
42
|
+
apply_shift_negative_correction
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def substitute(graph_to_substitute: common.Graph,
|
|
46
|
+
substitutions_list: List[common.BaseSubstitution]) -> common.Graph:
|
|
47
|
+
"""
|
|
48
|
+
Apply a list of substitutions on a graph.
|
|
49
|
+
Args:
|
|
50
|
+
graph: Graph to transform.
|
|
51
|
+
substitutions_list: List of substitutions to apply on the graph.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
Transformed graph after applying all substitutions in substitutions_list.
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
graph = copy.deepcopy(graph_to_substitute)
|
|
58
|
+
for substitution in substitutions_list:
|
|
59
|
+
matched_nodes = graph.filter(substitution.matcher_instance)
|
|
60
|
+
for idn in matched_nodes:
|
|
61
|
+
graph = substitution.substitute(graph, idn)
|
|
62
|
+
return graph
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def graph_marking_substitute(graph: Graph) -> Graph:
|
|
66
|
+
"""
|
|
67
|
+
Build a list of marking substitutions the graph should be transformed according to (before statistics
|
|
68
|
+
are being collected), apply these substitutions on the graph and return the transformed graph.
|
|
69
|
+
Args:
|
|
70
|
+
graph: Graph to apply substitutions on.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Transformed graph after marking substitutions were applied.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
marking_substitutions_list = [MarkActivation()] # mark activation layers that their inputs should not be quantized
|
|
77
|
+
return substitute(graph,
|
|
78
|
+
marking_substitutions_list)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def pre_statistics_collection_substitute(graph: Graph) -> Graph:
|
|
82
|
+
"""
|
|
83
|
+
Build a list of substitutions the graph should be transformed according to (before statistics
|
|
84
|
+
are being collected), apply these substitutions on the graph and return the transformed graph.
|
|
85
|
+
Args:
|
|
86
|
+
graph: Graph to apply substitutions on.
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
Transformed graph after substitutions.
|
|
90
|
+
"""
|
|
91
|
+
substitutions_list = [SeparableConvDecomposition(), # decompose separable node into depthwise and pointwise nodes
|
|
92
|
+
ActivationDecomposition(), # extract activation from linear op to an additional layer
|
|
93
|
+
BatchNormalizationFolding(), # fold batch normalization layer to the preceding linear layer
|
|
94
|
+
MarkActivation()] # mark activation layers that their inputs should not be quantized
|
|
95
|
+
|
|
96
|
+
return substitute(graph,
|
|
97
|
+
substitutions_list)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def post_statistics_collection_substitute(graph: Graph,
|
|
101
|
+
quant_config: QuantizationConfig,
|
|
102
|
+
fw_info: FrameworkInfo) -> Graph:
|
|
103
|
+
"""
|
|
104
|
+
Build a list of substitutions the graph should be transformed according to (after statistics
|
|
105
|
+
were collected), apply these substitutions on the graph and return the transformed graph.
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
graph: Graph to apply substitutions on.
|
|
109
|
+
quant_config: Quantization configuration to build the substitutions list according to.
|
|
110
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
111
|
+
groups of layers by how they should be quantized, etc.)
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
Transformed graph after substitutions.
|
|
115
|
+
"""
|
|
116
|
+
substitutions_list = []
|
|
117
|
+
######################################
|
|
118
|
+
# Scale Activations
|
|
119
|
+
######################################
|
|
120
|
+
if quant_config.input_scaling:
|
|
121
|
+
substitutions_list.append(InputScaling(quant_config,
|
|
122
|
+
fw_info))
|
|
123
|
+
substitutions_list.append(InputScalingWithPad(quant_config,
|
|
124
|
+
fw_info))
|
|
125
|
+
|
|
126
|
+
######################################
|
|
127
|
+
# Scale Activations
|
|
128
|
+
######################################
|
|
129
|
+
if quant_config.relu_unbound_correction:
|
|
130
|
+
substitutions_list.append(ReLUBoundCorrection(quant_config,
|
|
131
|
+
fw_info))
|
|
132
|
+
|
|
133
|
+
if quant_config.activation_channel_equalization:
|
|
134
|
+
substitutions_list.append(ScaleEqualization(quant_config,
|
|
135
|
+
fw_info))
|
|
136
|
+
substitutions_list.append(ScaleEqualizationWithPad(quant_config,
|
|
137
|
+
fw_info))
|
|
138
|
+
substitutions_list.append(ScaleEqualizationMidActivation(quant_config,
|
|
139
|
+
fw_info))
|
|
140
|
+
substitutions_list.append(ScaleEqualizationMidActivationWithPad(quant_config,
|
|
141
|
+
fw_info))
|
|
142
|
+
|
|
143
|
+
######################################
|
|
144
|
+
# Shift Negative Activations
|
|
145
|
+
######################################
|
|
146
|
+
if quant_config.shift_negative_activation_correction:
|
|
147
|
+
graph = apply_shift_negative_correction(graph, quant_config, fw_info)
|
|
148
|
+
|
|
149
|
+
return substitute(graph,
|
|
150
|
+
substitutions_list)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def pre_build_substitute(graph: Graph,
|
|
154
|
+
remove_relu_bound: bool = True) -> Graph:
|
|
155
|
+
"""
|
|
156
|
+
Build a list of substitutions the graph should be transformed according to (before building
|
|
157
|
+
the model back from its graph), apply these substitutions on the graph and return the transformed graph.
|
|
158
|
+
Args:
|
|
159
|
+
graph: Graph to apply substitutions on.
|
|
160
|
+
remove_relu_bound: Whether or not to remove bounds of bounded ReLUs in case the quantization threshold is
|
|
161
|
+
bounding the maximal value anyway.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
Transformed graph after substitutions.
|
|
165
|
+
"""
|
|
166
|
+
substitutions_list = []
|
|
167
|
+
if remove_relu_bound:
|
|
168
|
+
substitutions_list.append(RemoveReLUUpperBound())
|
|
169
|
+
|
|
170
|
+
return substitute(graph,
|
|
171
|
+
substitutions_list)
|
|
@@ -47,6 +47,8 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
47
47
|
"""
|
|
48
48
|
|
|
49
49
|
def __init__(self,
|
|
50
|
+
quantization_config: QuantizationConfig,
|
|
51
|
+
fw_info: FrameworkInfo,
|
|
50
52
|
matcher_instance):
|
|
51
53
|
"""
|
|
52
54
|
Matches: InputLayer -> (optional nodes) -> (Dense,Conv2D,DepthwiseConv2D,Conv2DTranspose)
|
|
@@ -55,9 +57,16 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
55
57
|
Create a substitution using different params which may affect the way this substitution is made.
|
|
56
58
|
The substitution is looking for edges in the graph which are input layers connected to linear layers.
|
|
57
59
|
Args:
|
|
60
|
+
quantization_config: QuantizationConfig containing parameters of how the model should be quantized.
|
|
61
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
62
|
+
groups of layers by how they should be quantized, etc.)
|
|
58
63
|
matcher_instance: matcher instance of type WalkMatcher
|
|
59
64
|
|
|
60
65
|
"""
|
|
66
|
+
|
|
67
|
+
self.fw_info = fw_info
|
|
68
|
+
self.qc = quantization_config
|
|
69
|
+
|
|
61
70
|
super().__init__(matcher_instance=matcher_instance)
|
|
62
71
|
|
|
63
72
|
def substitute(self,
|
|
@@ -80,7 +89,6 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
80
89
|
linear_layer = nodes_list[-1]
|
|
81
90
|
|
|
82
91
|
threshold = input_layer.activation_quantization_cfg.activation_quantization_params.get(THRESHOLD)
|
|
83
|
-
|
|
84
92
|
if threshold is None:
|
|
85
93
|
return graph
|
|
86
94
|
|
|
@@ -94,7 +102,7 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
94
102
|
w1_fixed = linear_layer.get_weights_by_keys(KERNEL) * scale_factor
|
|
95
103
|
linear_layer.set_weights_by_keys(KERNEL, w1_fixed)
|
|
96
104
|
|
|
97
|
-
graph.scale_stats_collector(input_layer, 1
|
|
105
|
+
graph.scale_stats_collector(input_layer, 1/scale_factor)
|
|
98
106
|
|
|
99
107
|
# After scaling weights may have different thresholds so it needs to be recalculated
|
|
100
108
|
for nqc in linear_layer.candidates_weights_quantization_cfg:
|
|
@@ -108,12 +116,18 @@ class InputScaling(BaseInputScaling):
|
|
|
108
116
|
Substitution extends BaseInputScaling to the case of Input-->Linear
|
|
109
117
|
"""
|
|
110
118
|
|
|
111
|
-
def __init__(self
|
|
119
|
+
def __init__(self,
|
|
120
|
+
quant_config: QuantizationConfig,
|
|
121
|
+
fw_info: FrameworkInfo):
|
|
112
122
|
"""
|
|
113
123
|
Initialize an InputScaling object.
|
|
124
|
+
Args:
|
|
125
|
+
quant_config: QuantizationConfig containing parameters of how the model should be quantized.
|
|
126
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
127
|
+
groups of layers by how they should be quantized, etc.)
|
|
114
128
|
"""
|
|
115
129
|
|
|
116
|
-
super().__init__(matcher_instance=INPUT_MATCHER)
|
|
130
|
+
super().__init__(quantization_config=quant_config, fw_info=fw_info, matcher_instance=INPUT_MATCHER)
|
|
117
131
|
|
|
118
132
|
|
|
119
133
|
class InputScalingWithPad(BaseInputScaling):
|
|
@@ -121,9 +135,15 @@ class InputScalingWithPad(BaseInputScaling):
|
|
|
121
135
|
Substitution extends BaseInputScaling to the case of Input-->ZeroPadding-->Linear
|
|
122
136
|
"""
|
|
123
137
|
|
|
124
|
-
def __init__(self
|
|
138
|
+
def __init__(self,
|
|
139
|
+
quant_config: QuantizationConfig,
|
|
140
|
+
fw_info: FrameworkInfo):
|
|
125
141
|
"""
|
|
126
142
|
Initialize an InputScalingWithPad object.
|
|
143
|
+
Args:
|
|
144
|
+
quant_config: QuantizationConfig containing parameters of how the model should be quantized.
|
|
145
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
146
|
+
groups of layers by how they should be quantized, etc.)
|
|
127
147
|
"""
|
|
128
148
|
|
|
129
|
-
super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD)
|
|
149
|
+
super().__init__(quantization_config=quant_config, fw_info=fw_info, matcher_instance=INPUT_MATCHER_WITH_PAD)
|
|
@@ -30,20 +30,27 @@ from model_compression_toolkit.keras.constants import KERNEL, BIAS, ACTIVATION,
|
|
|
30
30
|
from model_compression_toolkit.keras.constants import RELU
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
33
|
class ReLUBoundCorrection(common.BaseSubstitution):
|
|
37
34
|
"""
|
|
38
35
|
Substitution to scale the weights of two linear nodes, and the bound of non-linear between them
|
|
39
36
|
(if bounded) in order to use the entire constrained range when activations are quantized.
|
|
40
37
|
"""
|
|
41
38
|
|
|
42
|
-
def __init__(self
|
|
39
|
+
def __init__(self,
|
|
40
|
+
quant_config: QuantizationConfig,
|
|
41
|
+
fw_info: FrameworkInfo):
|
|
43
42
|
"""
|
|
44
43
|
Initialize a ReLUBoundCorrection object.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
quant_config: QuantizationConfig containing parameters of how the model should be quantized.
|
|
47
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
48
|
+
groups of layers by how they should be quantized, etc.)
|
|
45
49
|
"""
|
|
46
50
|
|
|
51
|
+
self.fw_info = fw_info
|
|
52
|
+
self.quant_config = quant_config
|
|
53
|
+
|
|
47
54
|
homogeneous_activation_nodes = NodeOperationMatcher(ReLU) | \
|
|
48
55
|
NodeOperationMatcher(Activation) & \
|
|
49
56
|
NodeFrameworkAttrMatcher(ACTIVATION, RELU)
|
|
@@ -115,4 +122,4 @@ class ReLUBoundCorrection(common.BaseSubstitution):
|
|
|
115
122
|
for nqc in second_op2d_node.candidates_weights_quantization_cfg:
|
|
116
123
|
nqc.calculate_and_set_weights_params(w2_fixed)
|
|
117
124
|
|
|
118
|
-
return graph
|
|
125
|
+
return graph
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from keras.engine.base_layer_v1 import Layer
|
|
17
17
|
from tensorflow import Tensor
|
|
18
18
|
from tensorflow.keras.models import Model
|
|
19
19
|
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
|
|
@@ -24,8 +24,7 @@ from model_compression_toolkit.common import Node
|
|
|
24
24
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
25
25
|
from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
|
|
26
26
|
MixedPrecisionQuantizationConfig
|
|
27
|
-
from model_compression_toolkit.keras.back2framework.model_builder import model_builder
|
|
28
|
-
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
27
|
+
from model_compression_toolkit.keras.back2framework.model_builder import ModelBuilderMode, model_builder
|
|
29
28
|
from model_compression_toolkit.keras.quantizer.mixed_precision.selective_weights_quantize_config import \
|
|
30
29
|
SelectiveWeightsQuantizeConfig
|
|
31
30
|
import numpy as np
|
|
@@ -35,7 +34,7 @@ def get_sensitivity_evaluation(graph: Graph,
|
|
|
35
34
|
quant_config: MixedPrecisionQuantizationConfig,
|
|
36
35
|
metrics_weights: np.ndarray,
|
|
37
36
|
representative_data_gen: Callable,
|
|
38
|
-
fw_info: FrameworkInfo)
|
|
37
|
+
fw_info: FrameworkInfo):
|
|
39
38
|
"""
|
|
40
39
|
Create a function to compute the sensitivity metric of an MP model (the sensitivity
|
|
41
40
|
is computed based on the similarity of the interest points' outputs between the MP model
|