mct-nightly 1.1.0.7012022.post2611__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/common/__init__.py +2 -2
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
- model_compression_toolkit/common/collectors/mean_collector.py +2 -3
- model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
- model_compression_toolkit/common/constants.py +0 -1
- model_compression_toolkit/common/framework_implementation.py +6 -22
- model_compression_toolkit/common/framework_info.py +7 -39
- model_compression_toolkit/common/graph/__init__.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +34 -34
- model_compression_toolkit/common/graph/edge.py +3 -3
- model_compression_toolkit/common/graph/graph_matchers.py +3 -3
- model_compression_toolkit/common/graph/graph_searches.py +4 -4
- model_compression_toolkit/common/graph/graph_vis.py +116 -0
- model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
- model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/common/model_collector.py +12 -14
- model_compression_toolkit/common/network_editors/actions.py +23 -19
- model_compression_toolkit/common/post_training_quantization.py +7 -20
- model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
- model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
- model_compression_toolkit/common/quantization/quantization_config.py +6 -6
- model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
- model_compression_toolkit/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
- model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
- model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
- model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
- model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
- model_compression_toolkit/keras/constants.py +0 -3
- model_compression_toolkit/keras/default_framework_info.py +7 -33
- model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
- model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
- model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
- model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
- model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
- model_compression_toolkit/keras/keras_implementation.py +8 -28
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/keras/quantization_facade.py +1 -5
- model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
- model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
- model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
- model_compression_toolkit/keras/reader/common.py +11 -9
- model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
- model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
- model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
- model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
- model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
- model_compression_toolkit/keras/reader/node_builder.py +15 -65
- model_compression_toolkit/keras/reader/reader.py +5 -5
- model_compression_toolkit/keras/tensor_marking.py +113 -0
- model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
- model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
- model_compression_toolkit/common/graph/functional_node.py +0 -59
- model_compression_toolkit/common/model_validation.py +0 -43
- model_compression_toolkit/common/node_prior_info.py +0 -29
- model_compression_toolkit/keras/keras_model_validation.py +0 -38
- model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py}
RENAMED
|
@@ -20,17 +20,20 @@ from typing import Any, Tuple
|
|
|
20
20
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
|
|
23
|
-
from model_compression_toolkit.common.framework_info import FrameworkInfo, ChannelAxis
|
|
24
23
|
from model_compression_toolkit.common.collectors.histogram_collector import HistogramCollector
|
|
25
24
|
from model_compression_toolkit.common.collectors.mean_collector import MeanCollector
|
|
26
25
|
from model_compression_toolkit.common.collectors.min_max_per_channel_collector import MinMaxPerChannelCollector
|
|
27
26
|
|
|
28
27
|
|
|
29
|
-
class
|
|
28
|
+
class BaseStatsContainer(object):
|
|
30
29
|
"""
|
|
31
|
-
Base class for statistics collection (
|
|
30
|
+
Base class for statistics collection container (contain multiple statistics collector such as mean collector,
|
|
32
31
|
histogram collector, etc.).
|
|
33
32
|
"""
|
|
33
|
+
def __init__(self):
|
|
34
|
+
# Disable histogram collection. Enable in specific collectors if needed
|
|
35
|
+
self.collect_histogram = False
|
|
36
|
+
self.use_min_max = False
|
|
34
37
|
|
|
35
38
|
def require_collection(self) -> bool:
|
|
36
39
|
"""
|
|
@@ -50,13 +53,12 @@ class BaseStatsCollector(object):
|
|
|
50
53
|
raise Exception(f'update_statistics is not implemented in {self.__class__.__name__}')
|
|
51
54
|
|
|
52
55
|
|
|
53
|
-
class
|
|
56
|
+
class StatsContainer(BaseStatsContainer):
|
|
54
57
|
"""
|
|
55
58
|
Class to wrap all statistics that are being collected for an input/output node.
|
|
56
59
|
"""
|
|
57
60
|
|
|
58
61
|
def __init__(self,
|
|
59
|
-
output_channel_index: ChannelAxis,
|
|
60
62
|
init_min_value: float = None,
|
|
61
63
|
init_max_value: float = None):
|
|
62
64
|
"""
|
|
@@ -64,17 +66,18 @@ class StatsCollector(BaseStatsCollector):
|
|
|
64
66
|
Set initial min/max values if are known.
|
|
65
67
|
|
|
66
68
|
Args:
|
|
67
|
-
output_channel_index: Index of output channels.
|
|
68
69
|
init_min_value: Initial min value for min/max stored values.
|
|
69
70
|
init_max_value: Initial max value for min/max stored values.
|
|
70
71
|
"""
|
|
71
72
|
|
|
72
73
|
super().__init__()
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
74
|
+
self.use_min_max = is_number(init_min_value) and is_number(init_max_value)
|
|
75
|
+
self.collect_histogram = True
|
|
76
|
+
if self.collect_histogram:
|
|
77
|
+
self.hc = HistogramCollector()
|
|
78
|
+
self.mc = MeanCollector()
|
|
75
79
|
self.mpcc = MinMaxPerChannelCollector(init_min_value=init_min_value,
|
|
76
|
-
init_max_value=init_max_value
|
|
77
|
-
axis=output_channel_index)
|
|
80
|
+
init_max_value=init_max_value)
|
|
78
81
|
|
|
79
82
|
def update_statistics(self, x: Any):
|
|
80
83
|
"""
|
|
@@ -85,7 +88,8 @@ class StatsCollector(BaseStatsCollector):
|
|
|
85
88
|
"""
|
|
86
89
|
|
|
87
90
|
x = standardize_tensor(x)
|
|
88
|
-
self.
|
|
91
|
+
if self.collect_histogram:
|
|
92
|
+
self.hc.update(x)
|
|
89
93
|
self.mc.update(x)
|
|
90
94
|
self.mpcc.update(x)
|
|
91
95
|
|
|
@@ -139,7 +143,7 @@ class StatsCollector(BaseStatsCollector):
|
|
|
139
143
|
return True
|
|
140
144
|
|
|
141
145
|
|
|
142
|
-
class
|
|
146
|
+
class NoStatsContainer(BaseStatsContainer):
|
|
143
147
|
"""
|
|
144
148
|
Class that inherits from base tensor.
|
|
145
149
|
Indicating that for a point in a graph we should not gather statistics.
|
|
@@ -207,51 +211,51 @@ def standardize_tensor(x: Any) -> np.ndarray:
|
|
|
207
211
|
return x
|
|
208
212
|
|
|
209
213
|
|
|
210
|
-
def shift_statistics(collector:
|
|
211
|
-
shift_value: np.ndarray) ->
|
|
214
|
+
def shift_statistics(collector: BaseStatsContainer,
|
|
215
|
+
shift_value: np.ndarray) -> BaseStatsContainer:
|
|
212
216
|
"""
|
|
213
|
-
Shift all statistics in collectors of a statistics
|
|
217
|
+
Shift all statistics in collectors of a statistics container by a
|
|
214
218
|
value (or a value per-channel).
|
|
215
219
|
|
|
216
220
|
Args:
|
|
217
|
-
collector: Statistics
|
|
221
|
+
collector: Statistics container to shift its collectors.
|
|
218
222
|
shift_value: Value to shift all statistics by.
|
|
219
223
|
|
|
220
224
|
Returns:
|
|
221
|
-
New copy of the
|
|
225
|
+
New copy of the container with shifted statistics.
|
|
222
226
|
|
|
223
227
|
"""
|
|
224
228
|
|
|
225
229
|
shifted_collector = deepcopy(collector)
|
|
226
|
-
if isinstance(collector,
|
|
230
|
+
if isinstance(collector, StatsContainer):
|
|
227
231
|
shifted_collector.mpcc.shift(shift_value)
|
|
228
232
|
shifted_collector.mc.shift(shift_value)
|
|
229
|
-
if shifted_collector.
|
|
233
|
+
if shifted_collector.collect_histogram:
|
|
230
234
|
shifted_collector.hc.shift(shift_value)
|
|
231
235
|
|
|
232
236
|
return shifted_collector
|
|
233
237
|
|
|
234
238
|
|
|
235
|
-
def scale_statistics(collector:
|
|
236
|
-
scale_value: np.ndarray) ->
|
|
239
|
+
def scale_statistics(collector: BaseStatsContainer,
|
|
240
|
+
scale_value: np.ndarray) -> BaseStatsContainer:
|
|
237
241
|
"""
|
|
238
|
-
Scale all statistics in collectors of a statistics
|
|
242
|
+
Scale all statistics in collectors of a statistics container
|
|
239
243
|
by a factor (or a factor per-channel).
|
|
240
244
|
|
|
241
245
|
Args:
|
|
242
|
-
collector: Statistics
|
|
246
|
+
collector: Statistics container to shift its collectors.
|
|
243
247
|
scale_value: Value to shift all statistics by.
|
|
244
248
|
|
|
245
249
|
Returns:
|
|
246
|
-
New copy of the
|
|
250
|
+
New copy of the container with scaled statistics.
|
|
247
251
|
|
|
248
252
|
"""
|
|
249
253
|
|
|
250
254
|
scaled_collector = deepcopy(collector)
|
|
251
|
-
if isinstance(collector,
|
|
255
|
+
if isinstance(collector, StatsContainer):
|
|
252
256
|
scaled_collector.mpcc.scale(scale_value)
|
|
253
257
|
scaled_collector.mc.scale(scale_value)
|
|
254
|
-
if scaled_collector.
|
|
258
|
+
if scaled_collector.collect_histogram:
|
|
255
259
|
scaled_collector.hc.scale(scale_value)
|
|
256
260
|
|
|
257
261
|
return scaled_collector
|
|
@@ -29,10 +29,10 @@ from tensorboard.compat.proto.summary_pb2 import HistogramProto
|
|
|
29
29
|
from tensorboard.compat.proto.summary_pb2 import Summary
|
|
30
30
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
|
31
31
|
from tensorboard.summary.writer.event_file_writer import EventFileWriter
|
|
32
|
-
from typing import List, Any, Dict
|
|
32
|
+
from typing import List, Any, Dict, Callable
|
|
33
33
|
|
|
34
|
-
from model_compression_toolkit.common import Graph,
|
|
35
|
-
from model_compression_toolkit.common.
|
|
34
|
+
from model_compression_toolkit.common import Graph, Node
|
|
35
|
+
from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
|
|
36
36
|
|
|
37
37
|
DEVICE_STEP_STATS = "/device:CPU:0"
|
|
38
38
|
|
|
@@ -138,7 +138,7 @@ class TensorboardWriter(object):
|
|
|
138
138
|
bucket_limit=bins.tolist(),
|
|
139
139
|
bucket=counts.tolist())
|
|
140
140
|
|
|
141
|
-
def __create_histo_event(statistics_collector:
|
|
141
|
+
def __create_histo_event(statistics_collector: BaseStatsContainer):
|
|
142
142
|
"""
|
|
143
143
|
Create an event of histogram, and attach it to a list of events outside
|
|
144
144
|
the scope called 'events'.
|
|
@@ -186,7 +186,7 @@ class TensorboardWriter(object):
|
|
|
186
186
|
|
|
187
187
|
"""
|
|
188
188
|
|
|
189
|
-
def __get_node_attr(n:
|
|
189
|
+
def __get_node_attr(n: Node) -> Dict[str, Any]:
|
|
190
190
|
"""
|
|
191
191
|
Create a dictionary to display as the node's attributes.
|
|
192
192
|
The dictionary contains information from node's framework attributes, quantization attributes
|
|
@@ -203,10 +203,7 @@ class TensorboardWriter(object):
|
|
|
203
203
|
if n.quantization_attr is not None:
|
|
204
204
|
attr.update(n.quantization_attr)
|
|
205
205
|
|
|
206
|
-
#
|
|
207
|
-
# if they exist at all, as we can log the initial graph,
|
|
208
|
-
# which its nodes do not have configurations yet.
|
|
209
|
-
# Log final config or unified candidates, not both
|
|
206
|
+
# log final config or unified candidates, not both
|
|
210
207
|
if n.final_weights_quantization_cfg is not None:
|
|
211
208
|
attr.update(n.final_weights_quantization_cfg.__dict__)
|
|
212
209
|
elif n.candidates_weights_quantization_cfg is not None:
|
|
@@ -216,7 +213,7 @@ class TensorboardWriter(object):
|
|
|
216
213
|
attr.update(n.activation_quantization_cfg.__dict__)
|
|
217
214
|
return attr
|
|
218
215
|
|
|
219
|
-
def __get_node_output_dims(n:
|
|
216
|
+
def __get_node_output_dims(n: Node) -> List[tuple]:
|
|
220
217
|
"""
|
|
221
218
|
Get node's output shapes. If the first dimension in an output shape is None,
|
|
222
219
|
it means the batch size is dynamic, and it's replaced with -1 to mark it.
|
|
@@ -240,7 +237,7 @@ class TensorboardWriter(object):
|
|
|
240
237
|
dims = [(-1,) + output_shape[1:] if output_shape[0] is None else output_shape]
|
|
241
238
|
return dims
|
|
242
239
|
|
|
243
|
-
def __create_node_stats(n:
|
|
240
|
+
def __create_node_stats(n: Node):
|
|
244
241
|
"""
|
|
245
242
|
Create a NodeExecStats for a node in the graph. A NodeExecStats contains the
|
|
246
243
|
memory and compute time a node requires.
|
|
@@ -21,7 +21,7 @@ from networkx.algorithms.dag import topological_sort
|
|
|
21
21
|
|
|
22
22
|
from tensorflow.keras.layers import Layer
|
|
23
23
|
from model_compression_toolkit import common
|
|
24
|
-
from model_compression_toolkit.common import Graph,
|
|
24
|
+
from model_compression_toolkit.common import Graph, Node
|
|
25
25
|
from model_compression_toolkit.keras.constants import LAYER_NAME
|
|
26
26
|
|
|
27
27
|
|
|
@@ -35,7 +35,7 @@ class OperationHandler(object):
|
|
|
35
35
|
self.node_to_fw_op_dict = instance_builder(self.node_sort) # hold dictionary from node to its equivalent
|
|
36
36
|
# Keras layer
|
|
37
37
|
|
|
38
|
-
def get_node_op_function(self, n:
|
|
38
|
+
def get_node_op_function(self, n: Node) -> Layer:
|
|
39
39
|
"""
|
|
40
40
|
Get the Keras layer that was built from the passed node.
|
|
41
41
|
|
|
@@ -58,7 +58,7 @@ class OperationHandler(object):
|
|
|
58
58
|
return op_func
|
|
59
59
|
|
|
60
60
|
|
|
61
|
-
def node_builder(n: common.
|
|
61
|
+
def node_builder(n: common.Node) -> Layer:
|
|
62
62
|
"""
|
|
63
63
|
Build a Keras layer from a node.
|
|
64
64
|
|
|
@@ -78,7 +78,7 @@ def node_builder(n: common.BaseNode) -> Layer:
|
|
|
78
78
|
return node_instance
|
|
79
79
|
|
|
80
80
|
|
|
81
|
-
def instance_builder(toposort: List[
|
|
81
|
+
def instance_builder(toposort: List[Node]) -> Dict[Node, Layer]:
|
|
82
82
|
"""
|
|
83
83
|
Build a dictionary of nodes to their corresponding Keras
|
|
84
84
|
layers, given a list of nodes.
|
|
@@ -20,26 +20,24 @@ import tensorflow as tf
|
|
|
20
20
|
if tf.__version__ < "2.6":
|
|
21
21
|
from tensorflow.keras.layers import Input
|
|
22
22
|
from tensorflow.python.keras.layers.core import TFOpLambda
|
|
23
|
-
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
|
|
24
|
-
from tensorflow.python.keras.layers import Layer
|
|
25
23
|
else:
|
|
26
24
|
from keras import Input
|
|
27
25
|
from keras.layers.core import TFOpLambda
|
|
28
|
-
from keras.engine.base_layer import TensorFlowOpLayer, Layer
|
|
29
26
|
|
|
30
27
|
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
28
|
+
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
|
|
29
|
+
from tensorflow.python.keras.layers import Layer
|
|
31
30
|
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
|
|
32
31
|
from typing import Tuple, Any, Dict, List
|
|
33
32
|
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
34
|
-
|
|
33
|
+
|
|
35
34
|
|
|
36
35
|
from model_compression_toolkit import common
|
|
37
36
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
38
37
|
from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
39
|
-
from model_compression_toolkit.keras.quantizer.mixed_precision.quantization_config_factory import
|
|
40
|
-
quantization_config_builder_mixed_precision
|
|
38
|
+
from model_compression_toolkit.keras.quantizer.mixed_precision.quantization_config_factory import quantization_config_builder_mixed_precision
|
|
41
39
|
from model_compression_toolkit.keras.quantizer.gradient_ptq.config_factory import quantization_config_builder_gptq
|
|
42
|
-
from model_compression_toolkit.common import
|
|
40
|
+
from model_compression_toolkit.common import Node, Graph
|
|
43
41
|
from model_compression_toolkit.common.graph.edge import EDGE_SINK_INDEX
|
|
44
42
|
from model_compression_toolkit.keras.back2framework.instance_builder import OperationHandler
|
|
45
43
|
from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
|
|
@@ -82,9 +80,9 @@ def is_layer_fake_quant(layer: Layer) -> bool:
|
|
|
82
80
|
isinstance(layer, TFOpLambda) and layer.symbol == FQ_NODE_OP_V2_4)
|
|
83
81
|
|
|
84
82
|
|
|
85
|
-
def build_input_tensors_list(node:
|
|
83
|
+
def build_input_tensors_list(node: Node,
|
|
86
84
|
graph: Graph,
|
|
87
|
-
node_to_output_tensors_dict: Dict[
|
|
85
|
+
node_to_output_tensors_dict: Dict[Node, List[TFReference]]) -> List[List[TFReference]]:
|
|
88
86
|
"""
|
|
89
87
|
Given a node, build a list of input tensors the node gets. The list is built
|
|
90
88
|
based on the node's incoming edges and previous nodes' output tensors.
|
|
@@ -107,10 +105,10 @@ def build_input_tensors_list(node: BaseNode,
|
|
|
107
105
|
return input_tensors
|
|
108
106
|
|
|
109
107
|
|
|
110
|
-
def run_operation(n:
|
|
108
|
+
def run_operation(n: Node,
|
|
111
109
|
input_tensors: List[List[TFReference]],
|
|
112
110
|
op_func: Layer,
|
|
113
|
-
input_nodes_to_input_tensors: Dict[
|
|
111
|
+
input_nodes_to_input_tensors: Dict[Node, Any],
|
|
114
112
|
mode: ModelBuilderMode = ModelBuilderMode.QUANTIZED) -> List[TFReference]:
|
|
115
113
|
"""
|
|
116
114
|
Applying the layer (op_func) to the input tensors (input_tensors).
|
|
@@ -131,43 +129,33 @@ def run_operation(n: BaseNode,
|
|
|
131
129
|
|
|
132
130
|
if len(input_tensors) == 0: # Placeholder handling
|
|
133
131
|
out_tensors_of_n = input_nodes_to_input_tensors[n]
|
|
134
|
-
if
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
raise Exception(f'{n.name} should be quantized, but activation quantization function is None')
|
|
142
|
-
|
|
132
|
+
if mode in [ModelBuilderMode.QUANTIZED, ModelBuilderMode.GPTQ, ModelBuilderMode.MIXEDPRECISION]:
|
|
133
|
+
# Adding a fake quant node to Input when in GPTQ mode because quantize_model doesn't quantize the input layer
|
|
134
|
+
assert n.activation_quantization_cfg is not None # Input layers should always have activation config
|
|
135
|
+
fake_quant = n.activation_quantization_cfg.activation_quantization_fn(n.activation_quantization_cfg.activation_n_bits,
|
|
136
|
+
n.activation_quantization_cfg.activation_is_signed,
|
|
137
|
+
n.activation_quantization_cfg.activation_quantization_params)
|
|
138
|
+
if fake_quant is not None:
|
|
143
139
|
out_tensors_of_n = fake_quant(out_tensors_of_n)
|
|
144
|
-
|
|
140
|
+
|
|
145
141
|
else:
|
|
146
142
|
input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
out_tensors_of_n = op_func(*input_tensors, *n.op_call_args, **n.op_call_kwargs)
|
|
143
|
+
|
|
144
|
+
# If operator expects a single input tensor, it cannot be a list as it should
|
|
145
|
+
# have a dtype field.
|
|
146
|
+
if len(input_tensors) == 1:
|
|
147
|
+
out_tensors_of_n = op_func(input_tensors[0], **n.op_call_args)
|
|
153
148
|
else:
|
|
154
|
-
|
|
155
|
-
# have a dtype field.
|
|
156
|
-
if len(input_tensors) == 1:
|
|
157
|
-
input_tensors = input_tensors[0]
|
|
158
|
-
out_tensors_of_n = op_func(input_tensors)
|
|
149
|
+
out_tensors_of_n = op_func(input_tensors, **n.op_call_args)
|
|
159
150
|
|
|
160
151
|
# Add a fake quant node if the node has an activation threshold.
|
|
161
|
-
if n.
|
|
162
|
-
if mode in [ModelBuilderMode.QUANTIZED,
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
raise Exception(f'{n.name} should be quantized, but activation quantization function is None')
|
|
169
|
-
|
|
170
|
-
out_tensors_of_n = fake_quant(out_tensors_of_n)
|
|
152
|
+
if n.activation_quantization_cfg is not None:
|
|
153
|
+
if mode in [ModelBuilderMode.QUANTIZED, ModelBuilderMode.MIXEDPRECISION] and n.activation_quantization_cfg.enable_activation_quantization:
|
|
154
|
+
fake_quant = n.activation_quantization_cfg.activation_quantization_fn(n.activation_quantization_cfg.activation_n_bits,
|
|
155
|
+
n.activation_quantization_cfg.activation_is_signed,
|
|
156
|
+
n.activation_quantization_cfg.activation_quantization_params)
|
|
157
|
+
if fake_quant is not None:
|
|
158
|
+
out_tensors_of_n = fake_quant(out_tensors_of_n)
|
|
171
159
|
|
|
172
160
|
return out_tensors_of_n
|
|
173
161
|
|
|
@@ -277,11 +265,10 @@ def model_builder(graph: common.Graph,
|
|
|
277
265
|
nodes = graph.find_node_by_name(get_node_name_from_layer(layer))
|
|
278
266
|
if len(nodes) == 1:
|
|
279
267
|
node = nodes[0]
|
|
280
|
-
#
|
|
281
|
-
if node.
|
|
282
|
-
return
|
|
283
|
-
return layer
|
|
284
|
-
|
|
268
|
+
# does not need to get wrapped as its weights are not quantized
|
|
269
|
+
if node.candidates_weights_quantization_cfg is None:
|
|
270
|
+
return layer
|
|
271
|
+
return QuantizeWrapper(layer, quantization_config_builder_mixed_precision(node, fw_info))
|
|
285
272
|
elif is_layer_fake_quant(layer):
|
|
286
273
|
return layer
|
|
287
274
|
else:
|
|
@@ -38,9 +38,6 @@ PAD_SAME = 'same'
|
|
|
38
38
|
RELU_MAX_VALUE = 'max_value'
|
|
39
39
|
THRESHOLD = 'threshold'
|
|
40
40
|
NEGATIVE_SLOPE = 'negative_slope'
|
|
41
|
-
CHANNELS_FORMAT = 'data_format'
|
|
42
|
-
CHANNELS_FORMAT_FIRST = 'channels_first'
|
|
43
|
-
CHANNELS_FORMAT_LAST = 'channels_last'
|
|
44
41
|
|
|
45
42
|
# Layers variables names:
|
|
46
43
|
KERNEL = 'kernel'
|
|
@@ -14,18 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
MaxPooling2D, Activation, ReLU, GlobalAveragePooling2D, Add, Multiply, AveragePooling2D, UpSampling2D, InputLayer, \
|
|
21
|
-
Concatenate, Softmax, PReLU, Flatten, Cropping2D
|
|
22
|
-
else:
|
|
23
|
-
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Reshape, ZeroPadding2D, \
|
|
24
|
-
Dropout, MaxPooling2D, Activation, ReLU, GlobalAveragePooling2D, Add, Multiply, AveragePooling2D, UpSampling2D, \
|
|
25
|
-
InputLayer, Concatenate, Softmax, PReLU, Flatten, Cropping2D
|
|
17
|
+
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Reshape, ZeroPadding2D, Dropout, \
|
|
18
|
+
MaxPooling2D, Activation, ReLU, GlobalAveragePooling2D, Add, Multiply, AveragePooling2D, UpSampling2D, InputLayer, \
|
|
19
|
+
Concatenate, Softmax, PReLU, Flatten, Cropping2D
|
|
26
20
|
|
|
27
21
|
from model_compression_toolkit.common.defaultdict import DefaultDict
|
|
28
|
-
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
22
|
+
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
29
23
|
from model_compression_toolkit.common.quantization.quantization_config import QuantizationMethod
|
|
30
24
|
from model_compression_toolkit.common.quantization.quantizers.kmeans_quantizer import kmeans_quantizer
|
|
31
25
|
from model_compression_toolkit.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
|
|
@@ -51,9 +45,7 @@ NO_QUANTIZATION = [Reshape,
|
|
|
51
45
|
Cropping2D,
|
|
52
46
|
ZeroPadding2D,
|
|
53
47
|
Dropout,
|
|
54
|
-
MaxPooling2D
|
|
55
|
-
tf.split,
|
|
56
|
-
tf.quantization.fake_quant_with_min_max_vars] # TODO: replace with marking
|
|
48
|
+
MaxPooling2D] # TODO: replace with marking
|
|
57
49
|
|
|
58
50
|
ACTIVATION = [Activation,
|
|
59
51
|
ReLU,
|
|
@@ -65,19 +57,7 @@ ACTIVATION = [Activation,
|
|
|
65
57
|
UpSampling2D,
|
|
66
58
|
InputLayer,
|
|
67
59
|
Concatenate,
|
|
68
|
-
PReLU
|
|
69
|
-
tf.add,
|
|
70
|
-
tf.multiply,
|
|
71
|
-
tf.reduce_mean,
|
|
72
|
-
tf.reduce_min,
|
|
73
|
-
tf.reduce_sum,
|
|
74
|
-
tf.reduce_max,
|
|
75
|
-
tf.image.resize,
|
|
76
|
-
tf.image.crop_and_resize,
|
|
77
|
-
tf.concat]
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
60
|
+
PReLU]
|
|
81
61
|
|
|
82
62
|
"""
|
|
83
63
|
Map each layer to a list of its' weights attributes that should get quantized.
|
|
@@ -131,11 +111,6 @@ WEIGHTS_QUANTIZER_MAPPING = {QuantizationMethod.POWER_OF_TWO: power_of_two_quant
|
|
|
131
111
|
QuantizationMethod.KMEANS: kmeans_quantizer,
|
|
132
112
|
QuantizationMethod.LUT_QUANTIZER: lut_kmeans_quantizer}
|
|
133
113
|
|
|
134
|
-
"""
|
|
135
|
-
Output channel index of the model's layers
|
|
136
|
-
"""
|
|
137
|
-
OUTPUT_CHANNEL_INDEX = ChannelAxis.NHWC
|
|
138
|
-
|
|
139
114
|
DEFAULT_KERAS_INFO = FrameworkInfo(KERNEL_OPS,
|
|
140
115
|
ACTIVATION,
|
|
141
116
|
NO_QUANTIZATION,
|
|
@@ -144,5 +119,4 @@ DEFAULT_KERAS_INFO = FrameworkInfo(KERNEL_OPS,
|
|
|
144
119
|
DEFAULT_CHANNEL_AXIS_DICT,
|
|
145
120
|
ACTIVATION2MINMAX,
|
|
146
121
|
LAYER2MINMAX,
|
|
147
|
-
KERNEL_ATTRIBUTES
|
|
148
|
-
OUTPUT_CHANNEL_INDEX)
|
|
122
|
+
KERNEL_ATTRIBUTES)
|
|
@@ -20,14 +20,14 @@ from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapp
|
|
|
20
20
|
from typing import Tuple, List
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
23
|
-
from model_compression_toolkit.common.graph.
|
|
23
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
24
24
|
from model_compression_toolkit.keras.constants import USE_BIAS
|
|
25
25
|
from model_compression_toolkit.keras.quantizer.gradient_ptq import ActivationQuantizeConfig, WeightQuantizeConfig, ActivationAndWeightQuantizeConfig
|
|
26
26
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
27
27
|
from tensorflow.keras.models import Model
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
def get_compare_points(input_graph: Graph) -> Tuple[List[
|
|
30
|
+
def get_compare_points(input_graph: Graph) -> Tuple[List[Node], List[str]]:
|
|
31
31
|
"""
|
|
32
32
|
Create a list of nodes with weights in a graph and a corresponding list
|
|
33
33
|
of their names for tensors comparison purposes.
|
|
@@ -15,14 +15,8 @@
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
import copy
|
|
18
|
-
import tensorflow as tf
|
|
19
|
-
|
|
20
|
-
# As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
|
|
21
|
-
if tf.__version__ < "2.6":
|
|
22
|
-
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
|
|
23
|
-
else:
|
|
24
|
-
from keras.engine.base_layer import TensorFlowOpLayer
|
|
25
18
|
|
|
19
|
+
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
|
|
26
20
|
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
|
|
27
21
|
|
|
28
22
|
from model_compression_toolkit import common
|
|
@@ -57,6 +57,7 @@ def gptq_training_wrapper(tg: Graph,
|
|
|
57
57
|
#########################################
|
|
58
58
|
# Build two models and compare points
|
|
59
59
|
#########################################
|
|
60
|
+
# TODO: maybe need to add pre_build substitutions here. Ask Elad
|
|
60
61
|
compare_points, _ = get_compare_points(tg) # get compare points
|
|
61
62
|
n = len(compare_points)
|
|
62
63
|
float_model, float_user_info = model_builder(tg,
|
model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py
CHANGED
|
@@ -21,7 +21,7 @@ from model_compression_toolkit.common.constants import FLOAT_32, DATA_TYPE
|
|
|
21
21
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
22
22
|
from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, \
|
|
23
23
|
NodeFrameworkAttrMatcher
|
|
24
|
-
from model_compression_toolkit.common.graph.
|
|
24
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
25
25
|
from model_compression_toolkit.keras.constants import LINEAR, ACTIVATION, TRAINABLE, LAYER_NAME
|
|
26
26
|
|
|
27
27
|
|
|
@@ -49,7 +49,7 @@ class ActivationDecomposition(common.BaseSubstitution):
|
|
|
49
49
|
|
|
50
50
|
def substitute(self,
|
|
51
51
|
graph: Graph,
|
|
52
|
-
op2d_node:
|
|
52
|
+
op2d_node: Node) -> Graph:
|
|
53
53
|
"""
|
|
54
54
|
Decompose the activation function in a linear node to a new activation layer.
|
|
55
55
|
Set activation function in the linear node to 'linear' (y=x).
|
|
@@ -70,14 +70,12 @@ class ActivationDecomposition(common.BaseSubstitution):
|
|
|
70
70
|
DATA_TYPE: FLOAT_32,
|
|
71
71
|
ACTIVATION: op2d_node.framework_attr.get(ACTIVATION)}
|
|
72
72
|
|
|
73
|
-
activation_node = common.graph.
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
73
|
+
activation_node = common.graph.Node(activation_node_name,
|
|
74
|
+
activation_fw_attr,
|
|
75
|
+
op2d_node.output_shape,
|
|
76
|
+
op2d_node.output_shape,
|
|
77
|
+
{},
|
|
78
|
+
Activation)
|
|
81
79
|
|
|
82
80
|
graph.add_node(activation_node)
|
|
83
81
|
graph.reconnect_out_edges(current_node=op2d_node,
|
|
@@ -24,7 +24,7 @@ from model_compression_toolkit import common
|
|
|
24
24
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
25
25
|
from model_compression_toolkit.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher, \
|
|
26
26
|
NodeFrameworkAttrMatcher
|
|
27
|
-
from model_compression_toolkit.common.graph.
|
|
27
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
28
28
|
from model_compression_toolkit.keras.constants import KERNEL, BIAS, USE_BIAS, LINEAR, ACTIVATION, LAYER_NAME, \
|
|
29
29
|
GAMMA, BETA, EPSILON, \
|
|
30
30
|
MOVING_MEAN, \
|
|
@@ -51,7 +51,7 @@ class BatchNormalizationFolding(common.BaseSubstitution):
|
|
|
51
51
|
|
|
52
52
|
def substitute(self,
|
|
53
53
|
graph: Graph,
|
|
54
|
-
edge_nodes: Tuple[
|
|
54
|
+
edge_nodes: Tuple[Node, Node]) -> Graph:
|
|
55
55
|
"""
|
|
56
56
|
Fold BatchNormalization into preceding linear layers.
|
|
57
57
|
|
|
@@ -21,7 +21,7 @@ from model_compression_toolkit import common
|
|
|
21
21
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
22
22
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
23
23
|
from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, WalkMatcher
|
|
24
|
-
from model_compression_toolkit.common.graph.
|
|
24
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
25
25
|
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
26
26
|
from model_compression_toolkit.common.constants import THRESHOLD
|
|
27
27
|
from model_compression_toolkit.keras.constants import KERNEL
|
|
@@ -62,7 +62,7 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
62
62
|
|
|
63
63
|
def substitute(self,
|
|
64
64
|
graph: Graph,
|
|
65
|
-
nodes_list: List[
|
|
65
|
+
nodes_list: List[Node]) -> Graph:
|
|
66
66
|
"""
|
|
67
67
|
Scale activation threshold for input layers, if they are followed by linear nodes. We first
|
|
68
68
|
scale their thresholds to a constrained threshold, and then fix it by scaling the linear op weights
|
|
@@ -21,7 +21,7 @@ from model_compression_toolkit import common
|
|
|
21
21
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
22
22
|
from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, \
|
|
23
23
|
NodeFrameworkAttrMatcher
|
|
24
|
-
from model_compression_toolkit.common.graph.
|
|
24
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
25
25
|
from model_compression_toolkit.keras.constants import LINEAR, ACTIVATION
|
|
26
26
|
|
|
27
27
|
|
|
@@ -55,7 +55,7 @@ class MarkActivation(common.BaseSubstitution):
|
|
|
55
55
|
|
|
56
56
|
def substitute(self,
|
|
57
57
|
graph: Graph,
|
|
58
|
-
edge: Tuple[
|
|
58
|
+
edge: Tuple[Node, Node]) -> Graph:
|
|
59
59
|
"""
|
|
60
60
|
Mark the first node in an edge that should not be quantized as so.
|
|
61
61
|
This can be done due to the following reasons:
|
|
@@ -69,5 +69,5 @@ class MarkActivation(common.BaseSubstitution):
|
|
|
69
69
|
Graph after applying the substitution.
|
|
70
70
|
"""
|
|
71
71
|
|
|
72
|
-
edge[0].
|
|
72
|
+
edge[0].output_quantization = False
|
|
73
73
|
return graph
|
|
@@ -20,11 +20,12 @@ import numpy as np
|
|
|
20
20
|
from tensorflow.keras.layers import ReLU, Activation, DepthwiseConv2D, Conv2DTranspose, Conv2D, Dense
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit import common
|
|
23
|
-
from model_compression_toolkit.common import Graph,
|
|
23
|
+
from model_compression_toolkit.common import FrameworkInfo, Graph, Node
|
|
24
24
|
from model_compression_toolkit.common.constants import THRESHOLD
|
|
25
25
|
from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher, \
|
|
26
26
|
NodeFrameworkAttrMatcher
|
|
27
|
-
from model_compression_toolkit.common.
|
|
27
|
+
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
28
|
+
from model_compression_toolkit.common.statistics_collector import is_number
|
|
28
29
|
from model_compression_toolkit.keras.constants import KERNEL, BIAS, ACTIVATION, RELU_MAX_VALUE
|
|
29
30
|
from model_compression_toolkit.keras.constants import RELU
|
|
30
31
|
|
|
@@ -60,7 +61,7 @@ class ReLUBoundCorrection(common.BaseSubstitution):
|
|
|
60
61
|
|
|
61
62
|
def substitute(self,
|
|
62
63
|
graph: Graph,
|
|
63
|
-
nodes_list: List[
|
|
64
|
+
nodes_list: List[Node]) -> Graph:
|
|
64
65
|
"""
|
|
65
66
|
Transform a list of nodes in a graph to use the entire constrained quantized range.
|
|
66
67
|
This is done by scaling two linear nodes with a non-linearity between them, if the non-linearity
|