mct-nightly 1.1.0.6012022.post2521__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
- {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
- {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/common/__init__.py +2 -2
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
- model_compression_toolkit/common/collectors/mean_collector.py +2 -3
- model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
- model_compression_toolkit/common/constants.py +0 -1
- model_compression_toolkit/common/framework_implementation.py +6 -22
- model_compression_toolkit/common/framework_info.py +7 -39
- model_compression_toolkit/common/graph/__init__.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +34 -34
- model_compression_toolkit/common/graph/edge.py +3 -3
- model_compression_toolkit/common/graph/graph_matchers.py +3 -3
- model_compression_toolkit/common/graph/graph_searches.py +4 -4
- model_compression_toolkit/common/graph/graph_vis.py +116 -0
- model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
- model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/common/model_collector.py +12 -14
- model_compression_toolkit/common/network_editors/actions.py +23 -19
- model_compression_toolkit/common/post_training_quantization.py +7 -20
- model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
- model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
- model_compression_toolkit/common/quantization/quantization_config.py +6 -6
- model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
- model_compression_toolkit/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
- model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
- model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
- model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
- model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
- model_compression_toolkit/keras/constants.py +0 -3
- model_compression_toolkit/keras/default_framework_info.py +7 -33
- model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
- model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
- model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
- model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
- model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
- model_compression_toolkit/keras/keras_implementation.py +8 -28
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/keras/quantization_facade.py +1 -5
- model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
- model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
- model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
- model_compression_toolkit/keras/reader/common.py +11 -9
- model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
- model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
- model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
- model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
- model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
- model_compression_toolkit/keras/reader/node_builder.py +15 -65
- model_compression_toolkit/keras/reader/reader.py +5 -5
- model_compression_toolkit/keras/tensor_marking.py +113 -0
- model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
- model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
- model_compression_toolkit/common/graph/functional_node.py +0 -59
- model_compression_toolkit/common/model_validation.py +0 -43
- model_compression_toolkit/common/node_prior_info.py +0 -29
- model_compression_toolkit/keras/keras_model_validation.py +0 -38
- model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
- {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
- {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
from tensorflow.keras.layers import DepthwiseConv2D
|
|
17
|
+
from tensorflow.keras.layers import DepthwiseConv2D, Conv2D, Conv2DTranspose, Dense
|
|
18
18
|
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
|
|
19
19
|
NoOpQuantizeConfig
|
|
20
20
|
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_registry import \
|
|
@@ -30,7 +30,7 @@ from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
|
30
30
|
MAX_LSBS_CHANGE = 8
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
def quantization_config_builder_gptq(n: common.BaseNode,
|
|
33
|
+
def quantization_config_builder_gptq(n: common.Node,
|
|
34
34
|
fw_info: FrameworkInfo) -> QuantizeConfig:
|
|
35
35
|
"""
|
|
36
36
|
Build a QuantizeConfig for a node according to its quantization configuration and
|
|
@@ -45,28 +45,29 @@ def quantization_config_builder_gptq(n: common.BaseNode,
|
|
|
45
45
|
quantization configuration).
|
|
46
46
|
"""
|
|
47
47
|
|
|
48
|
-
if n.
|
|
48
|
+
if n.activation_quantization() and n.weight_quantization():
|
|
49
49
|
qc = keras.quantizer.gradient_ptq.ActivationAndWeightQuantizeConfig(fw_info.get_kernel_op_attributes(n.layer_class),
|
|
50
50
|
n.final_weights_quantization_cfg.weights_quantization_params.get(THRESHOLD),
|
|
51
51
|
n.final_weights_quantization_cfg.weights_channels_axis,
|
|
52
52
|
n.final_weights_quantization_cfg.weights_n_bits,
|
|
53
53
|
n.activation_quantization_cfg.activation_quantization_params,
|
|
54
|
+
n.activation_quantization_cfg.activation_is_signed,
|
|
54
55
|
activation_num_bits=n.activation_quantization_cfg.activation_n_bits,
|
|
55
56
|
max_lsbs_change=MAX_LSBS_CHANGE
|
|
56
57
|
)
|
|
57
|
-
|
|
58
|
-
elif n.is_activation_quantization_enabled() and not n.is_weights_quantization_enabled():
|
|
58
|
+
elif n.activation_quantization():
|
|
59
59
|
qc = keras.quantizer.gradient_ptq.ActivationQuantizeConfig(n.activation_quantization_cfg.activation_quantization_params,
|
|
60
|
+
n.activation_quantization_cfg.activation_is_signed,
|
|
60
61
|
num_bits=n.activation_quantization_cfg.activation_n_bits)
|
|
61
|
-
|
|
62
|
-
elif n.is_weights_quantization_enabled() and not n.is_activation_quantization_enabled():
|
|
62
|
+
elif n.weight_quantization():
|
|
63
63
|
qc = keras.quantizer.gradient_ptq.WeightQuantizeConfig(fw_info.get_kernel_op_attributes(n.layer_class),
|
|
64
64
|
n.final_weights_quantization_cfg.weights_quantization_params.get(THRESHOLD),
|
|
65
65
|
n.final_weights_quantization_cfg.weights_channels_axis,
|
|
66
66
|
n.final_weights_quantization_cfg.weights_n_bits,
|
|
67
|
-
max_lsbs_change=MAX_LSBS_CHANGE
|
|
67
|
+
max_lsbs_change=MAX_LSBS_CHANGE
|
|
68
|
+
)
|
|
68
69
|
|
|
69
|
-
elif
|
|
70
|
+
elif n.no_quantization():
|
|
70
71
|
qc = NoOpQuantizeConfig()
|
|
71
72
|
|
|
72
73
|
else:
|
|
@@ -16,15 +16,7 @@
|
|
|
16
16
|
from typing import List, Tuple, Any, Dict
|
|
17
17
|
|
|
18
18
|
from tensorflow import Tensor
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
# As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
|
|
22
|
-
if tf.__version__ < "2.6":
|
|
23
|
-
from tensorflow.python.keras.layers import Layer
|
|
24
|
-
else:
|
|
25
|
-
from keras.engine.base_layer import Layer
|
|
26
|
-
|
|
27
|
-
|
|
19
|
+
from tensorflow.python.keras.layers import Layer
|
|
28
20
|
from tensorflow.python.training.tracking.data_structures import ListWrapper
|
|
29
21
|
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
|
|
30
22
|
|
|
@@ -25,7 +25,7 @@ from model_compression_toolkit.keras.quantizer.mixed_precision.selective_weights
|
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def quantization_config_builder_mixed_precision(n: common.
|
|
28
|
+
def quantization_config_builder_mixed_precision(n: common.Node,
|
|
29
29
|
fw_info: FrameworkInfo) -> QuantizeConfig:
|
|
30
30
|
"""
|
|
31
31
|
Build a QuantizeConfig for layers that should be wrapped in a QuantizeWrapper to
|
model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py
CHANGED
|
@@ -17,12 +17,7 @@ from typing import List, Tuple, Any, Dict
|
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
from tensorflow import Tensor
|
|
20
|
-
|
|
21
|
-
# As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
|
|
22
|
-
if tf.__version__ < "2.6":
|
|
23
|
-
from tensorflow.python.keras.layers import Layer
|
|
24
|
-
else:
|
|
25
|
-
from keras.engine.base_layer import Layer
|
|
20
|
+
from tensorflow.python.keras.layers import Layer
|
|
26
21
|
from tensorflow.python.training.tracking.data_structures import ListWrapper
|
|
27
22
|
from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
|
|
28
23
|
|
|
@@ -20,19 +20,20 @@ import tensorflow as tf
|
|
|
20
20
|
if tf.__version__ < "2.6":
|
|
21
21
|
from tensorflow.python.keras.engine.node import Node as KerasNode
|
|
22
22
|
from tensorflow.keras.layers import InputLayer
|
|
23
|
-
from tensorflow.python.keras.engine.functional import Functional
|
|
24
|
-
from tensorflow.python.keras.engine.sequential import Sequential
|
|
25
23
|
else:
|
|
26
24
|
from keras.engine.input_layer import InputLayer
|
|
27
25
|
from keras.engine.node import Node as KerasNode
|
|
28
|
-
from keras.engine.functional import Functional
|
|
29
|
-
from keras.engine.sequential import Sequential
|
|
30
26
|
|
|
31
|
-
from model_compression_toolkit.common.graph.base_node import BaseNode
|
|
32
27
|
|
|
28
|
+
from tensorflow.python.keras.engine.functional import Functional
|
|
33
29
|
|
|
30
|
+
from tensorflow.python.keras.engine.sequential import Sequential
|
|
34
31
|
|
|
35
|
-
|
|
32
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def is_node_an_input_layer(node: Node) -> bool:
|
|
36
37
|
"""
|
|
37
38
|
Checks if a node represents a Keras input layer.
|
|
38
39
|
Args:
|
|
@@ -41,14 +42,14 @@ def is_node_an_input_layer(node: BaseNode) -> bool:
|
|
|
41
42
|
Returns:
|
|
42
43
|
Whether the node represents an input layer or not.
|
|
43
44
|
"""
|
|
44
|
-
if isinstance(node, BaseNode):
|
|
45
|
+
if isinstance(node, Node):
|
|
45
46
|
return node.layer_class == InputLayer
|
|
46
47
|
elif isinstance(node, KerasNode):
|
|
47
48
|
return isinstance(node.layer, InputLayer)
|
|
48
49
|
else:
|
|
49
50
|
raise Exception('Node to check has to be either a graph node or a keras node')
|
|
50
51
|
|
|
51
|
-
def is_node_a_model(node: BaseNode) -> bool:
|
|
52
|
+
def is_node_a_model(node: Node) -> bool:
|
|
52
53
|
"""
|
|
53
54
|
Checks if a node represents a Keras model.
|
|
54
55
|
Args:
|
|
@@ -57,10 +58,11 @@ def is_node_a_model(node: BaseNode) -> bool:
|
|
|
57
58
|
Returns:
|
|
58
59
|
Whether the node represents a Keras model or not.
|
|
59
60
|
"""
|
|
60
|
-
if isinstance(node, BaseNode):
|
|
61
|
+
if isinstance(node, Node):
|
|
61
62
|
return node.layer_class in [Functional, Sequential]
|
|
62
63
|
elif isinstance(node, KerasNode):
|
|
63
64
|
return isinstance(node.layer, Functional) or isinstance(node.layer, Sequential)
|
|
64
65
|
else:
|
|
65
66
|
raise Exception('Node to check has to be either a graph node or a keras node')
|
|
66
67
|
|
|
68
|
+
# return node.layer_class in [Functional, Sequential]
|
|
@@ -15,18 +15,12 @@
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
import tensorflow as tf
|
|
18
|
-
|
|
19
|
-
if tf.__version__ < "2.6":
|
|
20
|
-
from tensorflow.python.keras.engine.node import Node as KerasNode
|
|
21
|
-
else:
|
|
22
|
-
from keras.engine.node import Node as KerasNode
|
|
23
|
-
|
|
24
|
-
|
|
18
|
+
from tensorflow.python.keras.engine.node import Node as KerasNode
|
|
25
19
|
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
26
20
|
from typing import List, Tuple
|
|
27
21
|
|
|
28
22
|
from model_compression_toolkit.common.graph.base_graph import OutTensor
|
|
29
|
-
from model_compression_toolkit.common.graph.
|
|
23
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
30
24
|
from model_compression_toolkit.keras.reader.common import is_node_an_input_layer
|
|
31
25
|
from model_compression_toolkit.keras.reader.node_builder import build_node
|
|
32
26
|
|
|
@@ -47,7 +41,7 @@ class ConnectivityHandler(object):
|
|
|
47
41
|
self._nodes2output_tensors = dict() # Node -> List[Tensor]
|
|
48
42
|
self._output_tensors2nodes = dict() # Tensor -> Node
|
|
49
43
|
|
|
50
|
-
def get_nodes(self) -> List[
|
|
44
|
+
def get_nodes(self) -> List[Node]:
|
|
51
45
|
"""
|
|
52
46
|
Returns: List of nodes in the connectivity handler.
|
|
53
47
|
"""
|
|
@@ -66,7 +60,7 @@ class ConnectivityHandler(object):
|
|
|
66
60
|
return self._input_tensors2nodes.get(tensor) is not None
|
|
67
61
|
|
|
68
62
|
def input_tensor2nodes(self,
|
|
69
|
-
in_tensor: str) -> List[
|
|
63
|
+
in_tensor: str) -> List[Node]:
|
|
70
64
|
"""
|
|
71
65
|
Returns a list of nodes that have a given tensor in their input tensors.
|
|
72
66
|
Args:
|
|
@@ -90,7 +84,7 @@ class ConnectivityHandler(object):
|
|
|
90
84
|
return self._output_tensors2nodes[out_tensor] if out_tensor in self._output_tensors2nodes else None
|
|
91
85
|
|
|
92
86
|
def node2input_tensors(self,
|
|
93
|
-
node:
|
|
87
|
+
node: Node) -> List[TFReference]:
|
|
94
88
|
"""
|
|
95
89
|
Get a list of input tensors of a node.
|
|
96
90
|
Args:
|
|
@@ -102,7 +96,7 @@ class ConnectivityHandler(object):
|
|
|
102
96
|
return self._nodes2input_tensors[node] if node in self._nodes2input_tensors else []
|
|
103
97
|
|
|
104
98
|
def node2output_tensors(self,
|
|
105
|
-
node:
|
|
99
|
+
node: Node) -> List[TFReference]:
|
|
106
100
|
"""
|
|
107
101
|
Get a list of output tensors of a node.
|
|
108
102
|
Args:
|
|
@@ -179,8 +173,8 @@ class ConnectivityHandler(object):
|
|
|
179
173
|
self._output_tensors2nodes[output_t] = node
|
|
180
174
|
|
|
181
175
|
def get_edge_indices(self,
|
|
182
|
-
src_node:
|
|
183
|
-
dst_node:
|
|
176
|
+
src_node: Node,
|
|
177
|
+
dst_node: Node,
|
|
184
178
|
connecting_tensor: TFReference) -> Tuple[int, int]:
|
|
185
179
|
"""
|
|
186
180
|
Get indices of an edge by its source/destination nodes and the connecting tensor which defines the edge.
|
|
@@ -199,7 +193,7 @@ class ConnectivityHandler(object):
|
|
|
199
193
|
return src_index, dst_index
|
|
200
194
|
|
|
201
195
|
def get_out_edges_params_list(self,
|
|
202
|
-
src_node:
|
|
196
|
+
src_node: Node) -> List[tuple]:
|
|
203
197
|
"""
|
|
204
198
|
Compute for a given node, all parameters of its outgoing edges.
|
|
205
199
|
Args:
|
|
@@ -18,11 +18,11 @@ from typing import List, Dict, Tuple
|
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
20
20
|
from model_compression_toolkit.common.graph.edge import Edge, convert_to_edge
|
|
21
|
-
from model_compression_toolkit.common.graph.base_node import BaseNode
|
|
21
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
22
22
|
from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
def merge_models_edges(inner_model_node:
|
|
25
|
+
def merge_models_edges(inner_model_node: Node,
|
|
26
26
|
outer_graph: Graph,
|
|
27
27
|
inner_graph: Graph) -> List[Edge]:
|
|
28
28
|
"""
|
|
@@ -91,8 +91,8 @@ def rewire_outgoing_edged(inner_model_outputs: List[OutTensor],
|
|
|
91
91
|
res_edges.remove(model_node_out_edge)
|
|
92
92
|
|
|
93
93
|
|
|
94
|
-
def rewire_incoming_edges(inner_inputs_out_edges: Dict[BaseNode, List[Edge]],
|
|
95
|
-
inner_model_inputs_dict: List[BaseNode],
|
|
94
|
+
def rewire_incoming_edges(inner_inputs_out_edges: Dict[Node, List[Edge]],
|
|
95
|
+
inner_model_inputs_dict: List[Node],
|
|
96
96
|
model_node_in_edges: List[Edge],
|
|
97
97
|
res_edge_list: List[Edge]):
|
|
98
98
|
"""
|
|
@@ -126,7 +126,7 @@ def rewire_incoming_edges(inner_inputs_out_edges: Dict[BaseNode, List[Edge]],
|
|
|
126
126
|
res_edge_list.remove(model_node_in_edge)
|
|
127
127
|
|
|
128
128
|
|
|
129
|
-
def get_inner_inputs_successors(inner_graph: Graph) -> Dict[BaseNode, List[Edge]]:
|
|
129
|
+
def get_inner_inputs_successors(inner_graph: Graph) -> Dict[Node, List[Edge]]:
|
|
130
130
|
"""
|
|
131
131
|
Compute out edges the input nodes of the inner model has.
|
|
132
132
|
Args:
|
|
@@ -147,7 +147,7 @@ def get_inner_inputs_successors(inner_graph: Graph) -> Dict[BaseNode, List[Edge]
|
|
|
147
147
|
return inner_inputs_out_edges
|
|
148
148
|
|
|
149
149
|
|
|
150
|
-
def get_model_node_edges(model_node:
|
|
150
|
+
def get_model_node_edges(model_node: Node,
|
|
151
151
|
outer_edge_list: List[Edge]) -> Tuple[List[Edge], List[Edge]]:
|
|
152
152
|
"""
|
|
153
153
|
Get incoming and outgoing edges the inner model node has in the outer graph.
|
|
@@ -15,13 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
18
|
-
from model_compression_toolkit.common.graph.
|
|
18
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
19
19
|
from model_compression_toolkit.keras.reader.nested_model.edges_merger import merge_models_edges
|
|
20
20
|
from model_compression_toolkit.keras.reader.nested_model.nodes_merger import merge_models_nodes
|
|
21
21
|
from model_compression_toolkit.keras.reader.nested_model.outputs_merger import merge_models_outputs
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def merge_graphs(inner_model_node:
|
|
24
|
+
def merge_graphs(inner_model_node: Node,
|
|
25
25
|
outer_graph: Graph,
|
|
26
26
|
inner_graph: Graph) -> Graph:
|
|
27
27
|
"""
|
|
@@ -20,12 +20,12 @@ import copy
|
|
|
20
20
|
from typing import List
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
23
|
-
from model_compression_toolkit.common.graph.
|
|
23
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def merge_models_nodes(inner_model_node:
|
|
26
|
+
def merge_models_nodes(inner_model_node: Node,
|
|
27
27
|
outer_graph: Graph,
|
|
28
|
-
inner_graph: Graph) -> List[
|
|
28
|
+
inner_graph: Graph) -> List[Node]:
|
|
29
29
|
"""
|
|
30
30
|
Given two MultiDiGraphs (one of an outer model and the second of the inner model), merge their nodes into
|
|
31
31
|
a single nodes list representing the nodes that should be in a single MultiDiGraph after unrolling the inner graph.
|
|
@@ -19,11 +19,11 @@ import copy
|
|
|
19
19
|
from typing import List
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
22
|
-
from model_compression_toolkit.common.graph.
|
|
22
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
23
23
|
from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def merge_models_outputs(inner_model_node:
|
|
26
|
+
def merge_models_outputs(inner_model_node: Node,
|
|
27
27
|
outer_graph: Graph,
|
|
28
28
|
inner_graph: Graph) -> List[OutTensor]:
|
|
29
29
|
"""
|
|
@@ -12,21 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from typing import Any
|
|
16
15
|
|
|
17
|
-
import tensorflow as tf
|
|
18
16
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
|
22
|
-
from tensorflow.python.keras.engine.node import Node as KerasNode
|
|
23
|
-
else:
|
|
24
|
-
from keras.layers.core import TFOpLambda, SlicingOpLambda
|
|
25
|
-
from keras.engine.keras_tensor import KerasTensor
|
|
26
|
-
from keras.engine.node import Node as KerasNode
|
|
17
|
+
import tensorflow as tf
|
|
18
|
+
from tensorflow.python.keras.engine.node import Node as KerasNode
|
|
27
19
|
|
|
28
|
-
from model_compression_toolkit.common.graph.
|
|
29
|
-
from model_compression_toolkit.common.graph.functional_node import FunctionalNode
|
|
20
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
30
21
|
|
|
31
22
|
keras = tf.keras
|
|
32
23
|
layers = keras.layers
|
|
@@ -35,7 +26,7 @@ REUSED_IDENTIFIER = '_reused_'
|
|
|
35
26
|
|
|
36
27
|
|
|
37
28
|
def build_node(node: KerasNode,
|
|
38
|
-
node_name_to_node: dict) ->
|
|
29
|
+
node_name_to_node: dict) -> Node:
|
|
39
30
|
"""
|
|
40
31
|
Build a node from a Keras node. A node contains all information to reconstruct the layer it's representing
|
|
41
32
|
in a model:
|
|
@@ -51,8 +42,7 @@ def build_node(node: KerasNode,
|
|
|
51
42
|
"""
|
|
52
43
|
keras_layer = node.layer # get the layer the node represents.
|
|
53
44
|
layer_config = keras_layer.get_config() # layer configuration to reconstruct it.
|
|
54
|
-
op_call_args = node.
|
|
55
|
-
op_call_kwargs = node.call_kwargs
|
|
45
|
+
op_call_args = node.call_kwargs
|
|
56
46
|
layer_class = type(keras_layer) # class path to instantiating it in back2framework.
|
|
57
47
|
weights = {v.name: v.numpy() for v in keras_layer.weights} # layer's weights
|
|
58
48
|
|
|
@@ -76,56 +66,16 @@ def build_node(node: KerasNode,
|
|
|
76
66
|
input_shape = keras_layer.get_input_shape_at(io_index)
|
|
77
67
|
output_shape = keras_layer.get_output_shape_at(io_index)
|
|
78
68
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
node = FunctionalNode(node_name,
|
|
89
|
-
layer_config,
|
|
90
|
-
input_shape,
|
|
91
|
-
output_shape,
|
|
92
|
-
weights,
|
|
93
|
-
layer_class,
|
|
94
|
-
[arg for arg in op_call_args if not isinstance(arg, KerasTensor)], # Do not hold the tensors that are in op_call_args
|
|
95
|
-
{k: v for k, v in op_call_kwargs.items() if not isinstance(v, KerasTensor)}, # In TF2.5 tensors are in kwargs as well.
|
|
96
|
-
is_reused,
|
|
97
|
-
reuse_group,
|
|
98
|
-
functional_op=keras_layer.function,
|
|
99
|
-
inputs_as_list=inputs_as_list)
|
|
100
|
-
else:
|
|
101
|
-
node = BaseNode(node_name,
|
|
102
|
-
layer_config,
|
|
103
|
-
input_shape,
|
|
104
|
-
output_shape,
|
|
105
|
-
weights,
|
|
106
|
-
layer_class,
|
|
107
|
-
is_reused,
|
|
108
|
-
reuse_group)
|
|
69
|
+
node = Node(node_name,
|
|
70
|
+
layer_config,
|
|
71
|
+
input_shape,
|
|
72
|
+
output_shape,
|
|
73
|
+
weights,
|
|
74
|
+
layer_class,
|
|
75
|
+
is_reused,
|
|
76
|
+
reuse_group,
|
|
77
|
+
op_call_args)
|
|
109
78
|
|
|
110
79
|
node_name_to_node[node_name] = node
|
|
111
|
-
return node
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
def __is_functional_inputs_a_list(op_call_args: Any) -> bool:
|
|
115
|
-
"""
|
|
116
|
-
Check whether the input tensors should be passed as a list
|
|
117
|
-
or not.
|
|
118
80
|
|
|
119
|
-
|
|
120
|
-
op_call_args: Arguments list to check.
|
|
121
|
-
|
|
122
|
-
Returns:
|
|
123
|
-
Whether the input tensors should be passed as a list or not.
|
|
124
|
-
"""
|
|
125
|
-
|
|
126
|
-
if len(op_call_args) > 0 and isinstance(op_call_args[0], list):
|
|
127
|
-
inputs_as_list = True
|
|
128
|
-
for arg in op_call_args[0]:
|
|
129
|
-
inputs_as_list = inputs_as_list and isinstance(arg, KerasTensor)
|
|
130
|
-
return inputs_as_list
|
|
131
|
-
return False
|
|
81
|
+
return node
|
|
@@ -24,8 +24,8 @@ from typing import List
|
|
|
24
24
|
|
|
25
25
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
26
26
|
from model_compression_toolkit.common.graph.edge import Edge
|
|
27
|
-
from model_compression_toolkit.common.graph.
|
|
28
|
-
from model_compression_toolkit.keras.reader.common import is_node_a_model
|
|
27
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
28
|
+
from model_compression_toolkit.keras.reader.common import is_node_a_model
|
|
29
29
|
from model_compression_toolkit.keras.reader.connectivity_handler import ConnectivityHandler
|
|
30
30
|
from model_compression_toolkit.keras.reader.nested_model.nested_model_handler import merge_graphs
|
|
31
31
|
|
|
@@ -70,6 +70,7 @@ def build_tensors_list(tensors_list) -> List[TFReference]:
|
|
|
70
70
|
tensor in
|
|
71
71
|
tensors_list]
|
|
72
72
|
|
|
73
|
+
|
|
73
74
|
def build_connectivity_handler(model: Model) -> ConnectivityHandler:
|
|
74
75
|
"""
|
|
75
76
|
Build a connectivity handler containing all information about connections in the model (nodes and
|
|
@@ -84,8 +85,7 @@ def build_connectivity_handler(model: Model) -> ConnectivityHandler:
|
|
|
84
85
|
connectivity_handler = ConnectivityHandler()
|
|
85
86
|
for nodes in model._nodes_by_depth.values():
|
|
86
87
|
for node in nodes: # nodes_by depth values are lists (each list for a different depth)
|
|
87
|
-
|
|
88
|
-
input_tensors = build_tensors_list(node_inputs) # build input tensors of the node
|
|
88
|
+
input_tensors = build_tensors_list(node.input_tensors) # build input tensors of the node
|
|
89
89
|
output_tensors = build_tensors_list(node.output_tensors) # build output tensors of the node
|
|
90
90
|
connectivity_handler.add_node(node,
|
|
91
91
|
input_tensors,
|
|
@@ -143,7 +143,7 @@ def parse_model(model: Model) -> Graph:
|
|
|
143
143
|
|
|
144
144
|
|
|
145
145
|
def flatten_nested_model(outer_graph: Graph,
|
|
146
|
-
inner_model_node:
|
|
146
|
+
inner_model_node: Node,
|
|
147
147
|
outer_keras_model: Model):
|
|
148
148
|
"""
|
|
149
149
|
Flat a nested model given two graphs: inner and outer models' graphs.
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from tensorflow.keras.layers import ReLU, Activation
|
|
18
|
+
|
|
19
|
+
from model_compression_toolkit import common
|
|
20
|
+
from model_compression_toolkit.common import FrameworkInfo
|
|
21
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
22
|
+
from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
|
|
23
|
+
from model_compression_toolkit.keras.constants import LINEAR, ACTIVATION, RELU_MAX_VALUE, THRESHOLD, NEGATIVE_SLOPE
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_stats_collector_for_activation_op(n: Node,
|
|
27
|
+
fw_info: FrameworkInfo) -> common.StatsContainer:
|
|
28
|
+
"""
|
|
29
|
+
Create and initial a statistics collector for an activation layer. If the activation function's min/max
|
|
30
|
+
output values are known, the statistics collector is initialized with these values.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
34
|
+
groups of layers by how they should be quantized, etc.)
|
|
35
|
+
n: Node to create a statistics collector for it.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
Statistics collector initialized with known min/max values.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
if n.layer_class == ReLU:
|
|
42
|
+
negative_slope = n.framework_attr[NEGATIVE_SLOPE]
|
|
43
|
+
threshold = n.framework_attr[THRESHOLD]
|
|
44
|
+
init_max = n.framework_attr[RELU_MAX_VALUE]
|
|
45
|
+
return common.StatsContainer(init_min_value=threshold if negative_slope == 0 else None,
|
|
46
|
+
init_max_value=init_max)
|
|
47
|
+
|
|
48
|
+
if n.layer_class == Activation:
|
|
49
|
+
init_min, init_max = fw_info.activation_min_max_mapping[n.framework_attr[ACTIVATION]]
|
|
50
|
+
return common.StatsContainer(init_min_value=init_min,
|
|
51
|
+
init_max_value=init_max)
|
|
52
|
+
|
|
53
|
+
if fw_info.layers_has_min_max(n.layer_class):
|
|
54
|
+
init_min, init_max = fw_info.layer_min_max_mapping[n.layer_class]
|
|
55
|
+
return common.StatsContainer(init_min_value=init_min,
|
|
56
|
+
init_max_value=init_max)
|
|
57
|
+
|
|
58
|
+
return common.StatsContainer()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def get_stats_collector_for_kernel_op(n: common.Node,
|
|
62
|
+
fw_info: FrameworkInfo) -> BaseStatsContainer:
|
|
63
|
+
"""
|
|
64
|
+
Create and initialize a statistics collector for a linear operator. If the layer has an activation function and
|
|
65
|
+
its min/max output values are known, the statistics collector is initialized with these values.
|
|
66
|
+
If the layer's output should not be quantized, NoStatsContainer is created.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
70
|
+
groups of layers by how they should be quantized, etc.)
|
|
71
|
+
n: Node to create a statistics collector for.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
BaseStatsContainer according to statistics that are collected.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
if n.framework_attr[ACTIVATION] == LINEAR and n.output_quantization:
|
|
78
|
+
return common.StatsContainer()
|
|
79
|
+
|
|
80
|
+
if n.framework_attr[ACTIVATION] == LINEAR and not n.output_quantization:
|
|
81
|
+
return common.NoStatsContainer()
|
|
82
|
+
|
|
83
|
+
if n.framework_attr[ACTIVATION] in fw_info.activation_min_max_mapping.keys():
|
|
84
|
+
min_value, max_value = fw_info.activation_min_max_mapping[n.framework_attr[ACTIVATION]]
|
|
85
|
+
return common.StatsContainer(init_min_value=min_value,
|
|
86
|
+
init_max_value=max_value)
|
|
87
|
+
|
|
88
|
+
return common.StatsContainer()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_node_stats_collector(node: common.Node,
|
|
92
|
+
fw_info: common.FrameworkInfo) -> common.statistics_collector.BaseStatsContainer:
|
|
93
|
+
"""
|
|
94
|
+
Gets a node and framework information (including its layer groups), and creates and returns a statistics collector for the node
|
|
95
|
+
according to the group the node is in.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
node: Node to create its statistics collector.
|
|
99
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
|
100
|
+
groups of layers by how they should be quantized, etc.)
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
Statistics collector for statistics collection for the node.
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
stats_collector = get_stats_collector_for_activation_op(node, fw_info)
|
|
107
|
+
if fw_info.in_no_quantization_ops(node): # node should not be quantized
|
|
108
|
+
stats_collector = common.NoStatsContainer()
|
|
109
|
+
|
|
110
|
+
if fw_info.in_kernel_ops(node): # node's kernel should be quantized
|
|
111
|
+
stats_collector = get_stats_collector_for_kernel_op(node, fw_info)
|
|
112
|
+
|
|
113
|
+
return stats_collector
|
|
@@ -22,10 +22,10 @@ from model_compression_toolkit.common import Graph
|
|
|
22
22
|
from model_compression_toolkit.common.similarity_analyzer import compute_cs
|
|
23
23
|
from model_compression_toolkit.keras.back2framework.model_builder import model_builder
|
|
24
24
|
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
25
|
-
from model_compression_toolkit.common.graph.
|
|
25
|
+
from model_compression_toolkit.common.graph.node import Node
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def get_compare_points(input_graph: Graph) -> Tuple[List[
|
|
28
|
+
def get_compare_points(input_graph: Graph) -> Tuple[List[Node], List[str]]:
|
|
29
29
|
"""
|
|
30
30
|
Create a list of nodes in a graph where we collect their output statistics, and a corresponding list
|
|
31
31
|
of their names for tensors comparison purposes.
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
from model_compression_toolkit import common
|
|
18
|
-
from model_compression_toolkit.common.collectors.statistics_collector import BaseStatsCollector
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def create_stats_collector_for_node(node: common.BaseNode,
|
|
22
|
-
output_channel_index: int) -> BaseStatsCollector:
|
|
23
|
-
"""
|
|
24
|
-
Gets a node and a groups list and create and return a statistics collector for a node
|
|
25
|
-
according to whether its statistics should be collected and the prior information we
|
|
26
|
-
have about this node.
|
|
27
|
-
|
|
28
|
-
Args:
|
|
29
|
-
node: Node to create its statistics collector.
|
|
30
|
-
output_channel_index: Index of output channels (for statistics per-channel).
|
|
31
|
-
|
|
32
|
-
Returns:
|
|
33
|
-
Statistics collector for statistics collection for the node.
|
|
34
|
-
"""
|
|
35
|
-
|
|
36
|
-
if node.is_activation_quantization_enabled():
|
|
37
|
-
stats_collector = common.StatsCollector(init_min_value=node.prior_info.min_output,
|
|
38
|
-
init_max_value=node.prior_info.max_output,
|
|
39
|
-
output_channel_index=output_channel_index)
|
|
40
|
-
else:
|
|
41
|
-
stats_collector = common.NoStatsCollector()
|
|
42
|
-
|
|
43
|
-
return stats_collector
|