mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.02122021-003117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/METADATA +3 -2
  2. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/RECORD +31 -38
  3. model_compression_toolkit/__init__.py +2 -6
  4. model_compression_toolkit/common/base_substitutions.py +1 -0
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
  7. model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
  8. model_compression_toolkit/common/graph/base_graph.py +2 -4
  9. model_compression_toolkit/common/graph/graph_matchers.py +3 -1
  10. model_compression_toolkit/common/graph/graph_searches.py +3 -1
  11. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
  12. model_compression_toolkit/common/network_editors/node_filters.py +1 -0
  13. model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
  14. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
  15. model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
  16. model_compression_toolkit/common/quantization/quantize_node.py +3 -5
  17. model_compression_toolkit/keras/__init__.py +2 -0
  18. model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
  19. model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
  20. model_compression_toolkit/keras/default_framework_info.py +0 -1
  21. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
  22. model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
  23. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
  24. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
  25. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
  26. model_compression_toolkit/keras/quantization_facade.py +524 -188
  27. model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
  28. model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
  29. model_compression_toolkit/common/framework_implementation.py +0 -239
  30. model_compression_toolkit/common/gptq/__init__.py +0 -14
  31. model_compression_toolkit/common/gptq/gptq_config.py +0 -65
  32. model_compression_toolkit/common/model_builder_mode.py +0 -34
  33. model_compression_toolkit/common/post_training_quantization.py +0 -459
  34. model_compression_toolkit/common/substitutions/__init__.py +0 -14
  35. model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
  36. model_compression_toolkit/keras/keras_implementation.py +0 -256
  37. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/LICENSE +0 -0
  38. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/WHEEL +0 -0
  39. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/top_level.txt +0 -0
model_compression_toolkit/keras/back2framework/model_builder.py
@@ -14,6 +14,8 @@
 # ==============================================================================
 
 
+from enum import Enum
+
 import tensorflow as tf
 
 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
@@ -24,7 +26,6 @@ else:
     from keras import Input
     from keras.layers.core import TFOpLambda
 
-from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
 from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
 from tensorflow.python.keras.layers import Layer
 from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
@@ -40,6 +41,7 @@ from model_compression_toolkit.keras.quantizer.gradient_ptq.config_factory impor
 from model_compression_toolkit.common import Node, Graph
 from model_compression_toolkit.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.keras.back2framework.instance_builder import OperationHandler
+from model_compression_toolkit.keras.graph_substitutions.substituter import pre_build_substitute
 from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
 
 # In tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda.
@@ -48,6 +50,22 @@ FQ_NODE_OP_V2_4 = 'quantization.fake_quant_with_min_max_vars'
 BATCH_INPUT_SHAPE = 'batch_input_shape'
 
 
+class ModelBuilderMode(Enum):
+    """
+    Mode for building the model back from a graph:
+    FLOAT - Build a model for statistics collection. The model's outputs list contains the output tensors of all
+    nodes in the graph.
+    QUANTIZED - Build a quantized model, using the nodes' quantization attributes for adding
+    quantization nodes to the model.
+    GPTQ - Build a quantized model, using the nodes' quantization attributes for wrapping
+    layers with QuantizeWrapper and for output comparing points.
+    """
+    FLOAT = 0
+    QUANTIZED = 1
+    GPTQ = 2
+    MIXEDPRECISION = 3
+
+
 def get_node_name_from_layer(layer: Layer) -> str:
     """
     Get a node's name from the layer it was built from. For TensorFlowOpLayer
@@ -182,6 +200,11 @@ def model_builder(graph: common.Graph,
     Returns:
         A tuple of the model, and a UserInformation object.
     """
+
+    # For quantized models, first apply some substitutions.
+    if mode != ModelBuilderMode.FLOAT:
+        graph = pre_build_substitute(graph)
+
     node_to_output_tensors_dict = dict()
     model_output_tensors = []
 
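Taken together, these hunks fold ModelBuilderMode into the Keras model builder and route every non-FLOAT build through pre_build_substitute first. Below is a minimal usage sketch, not taken from the package: it assumes `graph` is a Graph produced by MCT's Keras reader, and that `append2output` may be omitted (both calls in this diff pass it explicitly, so its default is an assumption here).

# A minimal sketch, assuming `graph` is a Graph built by MCT's Keras reader.
# After this diff, model_builder and ModelBuilderMode live in the same module.
from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode

# Float build: used for statistics collection; no substitutions are applied.
float_model, user_info = model_builder(graph, mode=ModelBuilderMode.FLOAT)

# Quantized build: pre_build_substitute(graph) runs first, since mode != FLOAT.
quantized_model, user_info = model_builder(graph, mode=ModelBuilderMode.QUANTIZED)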
model_compression_toolkit/keras/back2framework/model_collector.py (moved from model_compression_toolkit/common/model_collector.py)
@@ -14,14 +14,13 @@
 # ==============================================================================
 
 
-import numpy as np
 from typing import List
 
-from model_compression_toolkit import FrameworkInfo
-from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
+import numpy as np
+
 from model_compression_toolkit.common.graph.base_graph import Graph
 from model_compression_toolkit.common.logger import Logger
-from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
+from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
 
 
 class ModelCollector(object):
@@ -32,21 +31,16 @@ class ModelCollector(object):
     for thresholds calculations.
     """
 
-    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo):
+    def __init__(self, graph: Graph):
         """
         Build a Keras model from the passed graph, and set the model's
         outputs to be all layers' outputs.
 
         Args:
             graph: Graph to build a model from.
-            fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-
         """
 
         self.graph = graph
-        self.fw_impl = fw_impl
-        self.fw_info = fw_info
-
         node2fetch = []  # List of graph nodes, the model should output their outputs.
         stats_containers_list = []  # List of output statistics containers of nodes ordered
         # the same as node2fetch so statistics of outputs can be gathered for the correct statistics container.
@@ -72,10 +66,9 @@ class ModelCollector(object):
 
         # Build a float model and output all layers' outputs
         # (that should be collected) as the model's outputs
-        self.model, _ = self.fw_impl.model_builder(self.graph,
-                                                   mode=ModelBuilderMode.FLOAT,
-                                                   append2output=node2fetch,
-                                                   fw_info=self.fw_info)
+        self.model, _ = model_builder(self.graph,
+                                      mode=ModelBuilderMode.FLOAT,
+                                      append2output=node2fetch)
 
     def infer(self, inputs_list: List[np.ndarray]):
         """
@@ -84,9 +77,7 @@ class ModelCollector(object):
 
         Args:
             inputs_list: Inputs for the model inference.
-
         """
-
         # TODO: Thinking about delegating collections to framework
         # TODO: migrate datasets to framework datasets
         tensor_data = self.model(list(inputs_list))
@@ -98,6 +89,6 @@ class ModelCollector(object):
             if len(sc) != len(td):
                 Logger.exception('"tensor_data" and the model tensor_list must be of the same length')
             for tdi, sci in zip(td, sc):
-                sci.update_statistics(self.fw_impl.to_numpy(tdi))
+                sci.update_statistics(tdi.numpy())
         else:
-            sc.update_statistics(self.fw_impl.to_numpy(td))
+            sc.update_statistics(td.numpy())
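With the FrameworkImplementation indirection removed, ModelCollector is now constructed from a graph alone and converts output tensors via Tensor.numpy() directly. A short usage sketch, under the same assumption that `graph` comes from MCT's Keras reader:

# A minimal sketch, assuming `graph` is a Graph built by MCT's Keras reader.
import numpy as np

from model_compression_toolkit.keras.back2framework.model_collector import ModelCollector

mc = ModelCollector(graph)  # builds a float model that outputs every collectible node

# Each infer() call runs the model and feeds the outputs into the per-node
# statistics containers that threshold computation later reads.
for _ in range(10):
    mc.infer([np.random.random((1, 224, 224, 3))])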
model_compression_toolkit/keras/default_framework_info.py
@@ -27,7 +27,6 @@ from model_compression_toolkit.common.quantization.quantizers.power_of_two_quant
 from model_compression_toolkit.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
     KERNEL, DEPTHWISE_KERNEL
 from model_compression_toolkit.keras.quantizer.fake_quant_builder import constraint_quantization
-
 """
 Division of Keras layers by how they should be quantized.
 KERNEL_OPS: Layers whose coefficients should be quantized.
model_compression_toolkit/keras/gradient_ptq/training_wrapper.py
@@ -19,21 +19,71 @@ import tensorflow as tf
 from tqdm import tqdm
 
 from model_compression_toolkit import common
-from model_compression_toolkit.common.gptq.gptq_config import GradientPTQConfig
 from model_compression_toolkit.common import Graph
-from model_compression_toolkit.keras.back2framework.model_builder import model_builder
-from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
+from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
 from model_compression_toolkit.keras.gradient_ptq.graph_info import get_compare_points, \
     get_trainable_parameters
 from model_compression_toolkit.common.framework_info import FrameworkInfo
 from model_compression_toolkit.keras.gradient_ptq.graph_update import update_graph_after_gptq
+from model_compression_toolkit.keras.gradient_ptq.gptq_loss import \
+    multiple_tensors_mse_loss
+from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
 import numpy as np
 
 
+class GradientPTQConfig:
+    """
+    Configuration to use for quantization with GradientPTQ (experimental).
+    """
+
+    def __init__(self,
+                 n_iter: int,
+                 optimizer: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=0.0001),
+                 loss: Callable = multiple_tensors_mse_loss,
+                 log_function: Callable = None,
+                 train_bias: bool = True,
+                 representative_data_gen: Callable = None):
+        """
+        Initialize a GradientPTQConfig.
+
+        Args:
+            n_iter (int): Number of iterations to train.
+            optimizer (OptimizerV2): Optimizer to use.
+            loss (Callable): Loss to use. It should accept two lists of tf.Tensor: the first holds the quantized tensors, the second the float tensors.
+            log_function (Callable): Function to log information about the GPTQ process.
+            train_bias (bool): Whether to update the bias during the training or not.
+            representative_data_gen (Callable): Dataset generator.
+
+        Examples:
+            Create a GradientPTQConfig that runs for 5 iterations and uses a random dataset generator:
+
+            >>> import numpy as np
+            >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
+            >>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen)
+
+            An optimizer can be passed:
+
+            >>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen, optimizer=tf.keras.optimizers.Nadam(learning_rate=0.2))
+
+            To disable bias training, set train_bias to False (it is enabled by default):
+
+            >>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen, train_bias=False)
+
+            The configuration can then be passed to :func:`~model_compression_toolkit.keras_post_training_quantization`.
+
+        """
+        self.n_iter = n_iter
+        self.optimizer = optimizer
+        self.loss = loss
+        self.log_function = log_function
+        self.train_bias = train_bias
+        self.representative_data_gen = representative_data_gen
+
+
 def gptq_training_wrapper(tg: Graph,
                           representative_data_gen: Callable,
                           gptq_config: GradientPTQConfig,
-                          fw_info: FrameworkInfo) -> Graph:
+                          fw_info: FrameworkInfo):
     """
     Build two models from a graph: A teacher network (float model) and a student network (quantized model).
     Use the dataset generator to pass images through the teacher and student networks to get intermediate
@@ -57,17 +107,14 @@ def gptq_training_wrapper(tg: Graph,
     #########################################
     # Build two models and compare points
     #########################################
-    # TODO: maybe need to add pre_build substitutions here. Ask Elad
     compare_points, _ = get_compare_points(tg)  # get compare points
     n = len(compare_points)
     float_model, float_user_info = model_builder(tg,
                                                  mode=ModelBuilderMode.FLOAT,
-                                                 append2output=compare_points,
-                                                 fw_info=fw_info)
+                                                 append2output=compare_points)
     fxp_model, gptq_user_info = model_builder(tg,
-                                              mode=ModelBuilderMode.GPTQ,
-                                              append2output=compare_points,
-                                              fw_info=fw_info)
+                                              mode=ModelBuilderMode.GPTQ,
+                                              append2output=compare_points)
 
     trainable_weights = get_trainable_parameters(fxp_model,
                                                  fw_info,
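Since GradientPTQConfig now lives in the Keras training wrapper, an end-to-end call looks roughly as follows. This is a hedged sketch: the keras_post_training_quantization entry point appears in this diff (quantization_facade.py), but its `gptq_config` keyword is an assumption taken from the docstring's cross-reference, not confirmed here.

# A hedged sketch; the `gptq_config` keyword of keras_post_training_quantization
# is an assumption based on the docstring's cross-reference above.
import numpy as np
from tensorflow.keras.applications.mobilenet import MobileNet

import model_compression_toolkit as mct
from model_compression_toolkit.keras.gradient_ptq.training_wrapper import GradientPTQConfig

def repr_datagen():
    return [np.random.random((1, 224, 224, 3))]

gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen)
quantized_model, quantization_info = mct.keras_post_training_quantization(
    MobileNet(), repr_datagen, gptq_config=gptq_conf)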
model_compression_toolkit/keras/graph_substitutions/substituter.py (new file)
@@ -0,0 +1,171 @@
+# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import copy
+
+from typing import List
+
+from model_compression_toolkit import common
+from model_compression_toolkit.common.framework_info import FrameworkInfo
+from model_compression_toolkit.common.graph.base_graph import Graph
+from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
+
+from model_compression_toolkit.keras.graph_substitutions.substitutions.activation_decomposition import \
+    ActivationDecomposition
+from model_compression_toolkit.keras.graph_substitutions.substitutions.relu_bound_correction import \
+    ReLUBoundCorrection
+from model_compression_toolkit.keras.graph_substitutions.substitutions.batchnorm_folding import \
+    BatchNormalizationFolding
+from model_compression_toolkit.keras.graph_substitutions.substitutions.input_scaling import InputScaling, InputScalingWithPad
+from model_compression_toolkit.keras.graph_substitutions.substitutions.mark_activation import MarkActivation
+from model_compression_toolkit.keras.graph_substitutions.substitutions.remove_relu_upper_bound import \
+    RemoveReLUUpperBound
+from model_compression_toolkit.keras.graph_substitutions.substitutions.scale_equalization import \
+    ScaleEqualization, ScaleEqualizationWithPad, \
+    ScaleEqualizationMidActivation, ScaleEqualizationMidActivationWithPad
+from model_compression_toolkit.keras.graph_substitutions.substitutions.separableconv_decomposition import \
+    SeparableConvDecomposition
+from model_compression_toolkit.keras.graph_substitutions.substitutions.shift_negative_activation import \
+    apply_shift_negative_correction
+
+
+def substitute(graph_to_substitute: common.Graph,
+               substitutions_list: List[common.BaseSubstitution]) -> common.Graph:
+    """
+    Apply a list of substitutions on a graph.
+    Args:
+        graph_to_substitute: Graph to transform.
+        substitutions_list: List of substitutions to apply on the graph.
+
+    Returns:
+        Transformed graph after applying all substitutions in substitutions_list.
+    """
+
+    graph = copy.deepcopy(graph_to_substitute)
+    for substitution in substitutions_list:
+        matched_nodes = graph.filter(substitution.matcher_instance)
+        for idn in matched_nodes:
+            graph = substitution.substitute(graph, idn)
+    return graph
+
+
+def graph_marking_substitute(graph: Graph) -> Graph:
+    """
+    Build a list of marking substitutions the graph should be transformed according to (before statistics
+    are collected), apply these substitutions on the graph and return the transformed graph.
+    Args:
+        graph: Graph to apply substitutions on.
+
+    Returns:
+        Transformed graph after marking substitutions were applied.
+    """
+
+    marking_substitutions_list = [MarkActivation()]  # mark activation layers whose inputs should not be quantized
+    return substitute(graph,
+                      marking_substitutions_list)
+
+
+def pre_statistics_collection_substitute(graph: Graph) -> Graph:
+    """
+    Build a list of substitutions the graph should be transformed according to (before statistics
+    are collected), apply these substitutions on the graph and return the transformed graph.
+    Args:
+        graph: Graph to apply substitutions on.
+
+    Returns:
+        Transformed graph after substitutions.
+    """
+    substitutions_list = [SeparableConvDecomposition(),  # decompose a separable node into depthwise and pointwise nodes
+                          ActivationDecomposition(),  # extract the activation from a linear op into an additional layer
+                          BatchNormalizationFolding(),  # fold a batch normalization layer into the preceding linear layer
+                          MarkActivation()]  # mark activation layers whose inputs should not be quantized
+
+    return substitute(graph,
+                      substitutions_list)
+
+
+def post_statistics_collection_substitute(graph: Graph,
+                                          quant_config: QuantizationConfig,
+                                          fw_info: FrameworkInfo) -> Graph:
+    """
+    Build a list of substitutions the graph should be transformed according to (after statistics
+    were collected), apply these substitutions on the graph and return the transformed graph.
+
+    Args:
+        graph: Graph to apply substitutions on.
+        quant_config: Quantization configuration to build the substitutions list according to.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+        groups of layers by how they should be quantized, etc.)
+
+    Returns:
+        Transformed graph after substitutions.
+    """
+    substitutions_list = []
+    ######################################
+    # Scale Inputs
+    ######################################
+    if quant_config.input_scaling:
+        substitutions_list.append(InputScaling(quant_config,
+                                               fw_info))
+        substitutions_list.append(InputScalingWithPad(quant_config,
+                                                      fw_info))
+
+    ######################################
+    # Scale Activations
+    ######################################
+    if quant_config.relu_unbound_correction:
+        substitutions_list.append(ReLUBoundCorrection(quant_config,
+                                                      fw_info))
+
+    if quant_config.activation_channel_equalization:
+        substitutions_list.append(ScaleEqualization(quant_config,
+                                                    fw_info))
+        substitutions_list.append(ScaleEqualizationWithPad(quant_config,
+                                                           fw_info))
+        substitutions_list.append(ScaleEqualizationMidActivation(quant_config,
+                                                                 fw_info))
+        substitutions_list.append(ScaleEqualizationMidActivationWithPad(quant_config,
+                                                                        fw_info))
+
+    ######################################
+    # Shift Negative Activations
+    ######################################
+    if quant_config.shift_negative_activation_correction:
+        graph = apply_shift_negative_correction(graph, quant_config, fw_info)
+
+    return substitute(graph,
+                      substitutions_list)
+
+
+def pre_build_substitute(graph: Graph,
+                         remove_relu_bound: bool = True) -> Graph:
+    """
+    Build a list of substitutions the graph should be transformed according to (before building
+    the model back from its graph), apply these substitutions on the graph and return the transformed graph.
+    Args:
+        graph: Graph to apply substitutions on.
+        remove_relu_bound: Whether or not to remove the bounds of bounded ReLUs in case the quantization
+        threshold bounds the maximal value anyway.
+
+    Returns:
+        Transformed graph after substitutions.
+    """
+    substitutions_list = []
+    if remove_relu_bound:
+        substitutions_list.append(RemoveReLUUpperBound())
+
+    return substitute(graph,
+                      substitutions_list)
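The new module centralizes the substitution pipeline around statistics collection. A sketch of how a caller would chain the stages, assuming `graph`, `quant_config`, and `fw_info` were created elsewhere (e.g., by the reader and the quantization facade):

# A sketch of the pipeline order implied by the function names; `graph`,
# `quant_config`, and `fw_info` are assumed to exist already.
from model_compression_toolkit.keras.graph_substitutions.substituter import (
    pre_statistics_collection_substitute, post_statistics_collection_substitute)

# Before statistics: structural rewrites (decompositions, BN folding, marking).
graph = pre_statistics_collection_substitute(graph)

# ... statistics collection runs here, e.g. via ModelCollector ...

# After statistics: scaling/equalization corrections that depend on thresholds.
graph = post_statistics_collection_substitute(graph, quant_config, fw_info)

# pre_build_substitute(graph) then runs inside model_builder for any
# non-FLOAT ModelBuilderMode, per the model_builder.py hunk above.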
model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py
@@ -47,6 +47,8 @@ class BaseInputScaling(common.BaseSubstitution):
     """
 
     def __init__(self,
+                 quantization_config: QuantizationConfig,
+                 fw_info: FrameworkInfo,
                  matcher_instance):
         """
         Matches: InputLayer -> (optional nodes) -> (Dense,Conv2D,DepthwiseConv2D,Conv2DTranspose)
@@ -55,9 +57,16 @@ class BaseInputScaling(common.BaseSubstitution):
         Create a substitution using different params which may affect the way this substitution is made.
         The substitution looks for edges in the graph which are input layers connected to linear layers.
         Args:
+            quantization_config: QuantizationConfig containing parameters of how the model should be quantized.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+            groups of layers by how they should be quantized, etc.)
             matcher_instance: matcher instance of type WalkMatcher
 
         """
+
+        self.fw_info = fw_info
+        self.qc = quantization_config
+
         super().__init__(matcher_instance=matcher_instance)
 
     def substitute(self,
@@ -80,7 +89,6 @@ class BaseInputScaling(common.BaseSubstitution):
         linear_layer = nodes_list[-1]
 
         threshold = input_layer.activation_quantization_cfg.activation_quantization_params.get(THRESHOLD)
-
         if threshold is None:
             return graph
 
@@ -94,7 +102,7 @@ class BaseInputScaling(common.BaseSubstitution):
         w1_fixed = linear_layer.get_weights_by_keys(KERNEL) * scale_factor
         linear_layer.set_weights_by_keys(KERNEL, w1_fixed)
 
-        graph.scale_stats_collector(input_layer, 1 / scale_factor)
+        graph.scale_stats_collector(input_layer, 1/scale_factor)
 
         # After scaling, the weights may have different thresholds, so they need to be recalculated
         for nqc in linear_layer.candidates_weights_quantization_cfg:
@@ -108,12 +116,18 @@ class InputScaling(BaseInputScaling):
     Substitution extends BaseInputScaling to the case of Input-->Linear
     """
 
-    def __init__(self):
+    def __init__(self,
+                 quant_config: QuantizationConfig,
+                 fw_info: FrameworkInfo):
         """
         Initialize an InputScaling object.
+        Args:
+            quant_config: QuantizationConfig containing parameters of how the model should be quantized.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+            groups of layers by how they should be quantized, etc.)
         """
 
-        super().__init__(matcher_instance=INPUT_MATCHER)
+        super().__init__(quantization_config=quant_config, fw_info=fw_info, matcher_instance=INPUT_MATCHER)
 
 
 class InputScalingWithPad(BaseInputScaling):
@@ -121,9 +135,15 @@ class InputScalingWithPad(BaseInputScaling):
     Substitution extends BaseInputScaling to the case of Input-->ZeroPadding-->Linear
     """
 
-    def __init__(self):
+    def __init__(self,
+                 quant_config: QuantizationConfig,
+                 fw_info: FrameworkInfo):
        """
         Initialize an InputScalingWithPad object.
+        Args:
+            quant_config: QuantizationConfig containing parameters of how the model should be quantized.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+            groups of layers by how they should be quantized, etc.)
         """
 
-        super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD)
+        super().__init__(quantization_config=quant_config, fw_info=fw_info, matcher_instance=INPUT_MATCHER_WITH_PAD)
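The correctness of input scaling rests on the homogeneity of linear layers: the kernel is scaled by scale_factor while the input statistics (and hence the quantized input) are scaled by 1/scale_factor, leaving the layer's product unchanged. A numeric check of that identity (not MCT code):

# Not MCT code: a numeric check of the identity input scaling relies on.
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((1, 8))        # an input batch
w = rng.random((8, 4))        # a Dense kernel
s = 2.5                       # a hypothetical scale_factor

# Scaling the input by 1/s and the kernel by s preserves the matmul output.
assert np.allclose(x @ w, (x / s) @ (w * s))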
model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py
@@ -30,20 +30,27 @@ from model_compression_toolkit.keras.constants import KERNEL, BIAS, ACTIVATION,
 from model_compression_toolkit.keras.constants import RELU
 
 
-
-
-
 class ReLUBoundCorrection(common.BaseSubstitution):
     """
     Substitution to scale the weights of two linear nodes, and the bound of the non-linearity between them
     (if bounded), in order to use the entire constrained range when activations are quantized.
     """
 
-    def __init__(self):
+    def __init__(self,
+                 quant_config: QuantizationConfig,
+                 fw_info: FrameworkInfo):
         """
         Initialize a ReLUBoundCorrection object.
+
+        Args:
+            quant_config: QuantizationConfig containing parameters of how the model should be quantized.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+            groups of layers by how they should be quantized, etc.)
         """
 
+        self.fw_info = fw_info
+        self.quant_config = quant_config
+
         homogeneous_activation_nodes = NodeOperationMatcher(ReLU) | \
                                        NodeOperationMatcher(Activation) & \
                                        NodeFrameworkAttrMatcher(ACTIVATION, RELU)
@@ -115,4 +122,4 @@ class ReLUBoundCorrection(common.BaseSubstitution):
         for nqc in second_op2d_node.candidates_weights_quantization_cfg:
             nqc.calculate_and_set_weights_params(w2_fixed)
 
-        return graph
+        return graph
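This substitution relies on the positive homogeneity of bounded ReLUs: for k > 0, min(relu(k*a), k*b) = k*min(relu(a), b), so scaling the first linear node's weights by k, the ReLU bound by k, and the following linear node's weights by 1/k preserves the composed function while stretching the activation to fill its quantized range. A numeric check (not MCT code):

# Not MCT code: a numeric check of the scaling identity behind the correction.
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(1000)   # pre-activations of the first linear node
b = 6.0                         # upper bound of a bounded ReLU (e.g., ReLU6)
k = 1.7                         # a hypothetical scale factor

def bounded_relu(x, bound):
    return np.minimum(np.maximum(x, 0.0), bound)

# The 1/k factor is absorbed by the second linear node's weights.
assert np.allclose(bounded_relu(a, b), bounded_relu(k * a, k * b) / k)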
model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from tensorflow.python.layers.base import Layer
+from keras.engine.base_layer_v1 import Layer
 from tensorflow import Tensor
 from tensorflow.keras.models import Model
 from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
@@ -24,8 +24,7 @@ from model_compression_toolkit.common import Node
 from model_compression_toolkit.common.graph.base_graph import Graph
 from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
-from model_compression_toolkit.keras.back2framework.model_builder import model_builder
-from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
+from model_compression_toolkit.keras.back2framework.model_builder import ModelBuilderMode, model_builder
 from model_compression_toolkit.keras.quantizer.mixed_precision.selective_weights_quantize_config import \
     SelectiveWeightsQuantizeConfig
 import numpy as np
@@ -35,7 +34,7 @@ def get_sensitivity_evaluation(graph: Graph,
                                quant_config: MixedPrecisionQuantizationConfig,
                                metrics_weights: np.ndarray,
                                representative_data_gen: Callable,
-                               fw_info: FrameworkInfo) -> Callable:
+                               fw_info: FrameworkInfo):
     """
     Create a function to compute the sensitivity metric of an MP model (the sensitivity
     is computed based on the similarity of the interest points' outputs between the MP model