mct-nightly 1.1.0.7012022.post2611__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
  2. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
  3. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/common/__init__.py +2 -2
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
  7. model_compression_toolkit/common/collectors/mean_collector.py +2 -3
  8. model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
  9. model_compression_toolkit/common/constants.py +0 -1
  10. model_compression_toolkit/common/framework_implementation.py +6 -22
  11. model_compression_toolkit/common/framework_info.py +7 -39
  12. model_compression_toolkit/common/graph/__init__.py +1 -1
  13. model_compression_toolkit/common/graph/base_graph.py +34 -34
  14. model_compression_toolkit/common/graph/edge.py +3 -3
  15. model_compression_toolkit/common/graph/graph_matchers.py +3 -3
  16. model_compression_toolkit/common/graph/graph_searches.py +4 -4
  17. model_compression_toolkit/common/graph/graph_vis.py +116 -0
  18. model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
  19. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
  20. model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  21. model_compression_toolkit/common/model_collector.py +12 -14
  22. model_compression_toolkit/common/network_editors/actions.py +23 -19
  23. model_compression_toolkit/common/post_training_quantization.py +7 -20
  24. model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
  25. model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
  26. model_compression_toolkit/common/quantization/quantization_config.py +6 -6
  27. model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
  28. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
  29. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
  30. model_compression_toolkit/common/quantization/quantize_node.py +2 -2
  31. model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
  32. model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
  33. model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
  34. model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
  35. model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
  36. model_compression_toolkit/keras/constants.py +0 -3
  37. model_compression_toolkit/keras/default_framework_info.py +7 -33
  38. model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
  39. model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
  40. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
  41. model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
  42. model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
  43. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
  44. model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
  45. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
  46. model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
  47. model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
  48. model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
  49. model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
  50. model_compression_toolkit/keras/keras_implementation.py +8 -28
  51. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
  52. model_compression_toolkit/keras/quantization_facade.py +1 -5
  53. model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
  54. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
  55. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
  56. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
  57. model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
  58. model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
  59. model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
  60. model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
  61. model_compression_toolkit/keras/reader/common.py +11 -9
  62. model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
  63. model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
  64. model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
  65. model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
  66. model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
  67. model_compression_toolkit/keras/reader/node_builder.py +15 -65
  68. model_compression_toolkit/keras/reader/reader.py +5 -5
  69. model_compression_toolkit/keras/tensor_marking.py +113 -0
  70. model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
  71. model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
  72. model_compression_toolkit/common/graph/functional_node.py +0 -59
  73. model_compression_toolkit/common/model_validation.py +0 -43
  74. model_compression_toolkit/common/node_prior_info.py +0 -29
  75. model_compression_toolkit/keras/keras_model_validation.py +0 -38
  76. model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
  77. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
  78. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from tensorflow.keras.layers import DepthwiseConv2D
17
+ from tensorflow.keras.layers import DepthwiseConv2D, Conv2D, Conv2DTranspose, Dense
18
18
  from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
19
19
  NoOpQuantizeConfig
20
20
  from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_registry import \
@@ -30,7 +30,7 @@ from model_compression_toolkit.common.framework_info import FrameworkInfo
30
30
  MAX_LSBS_CHANGE = 8
31
31
 
32
32
 
33
- def quantization_config_builder_gptq(n: common.BaseNode,
33
+ def quantization_config_builder_gptq(n: common.Node,
34
34
  fw_info: FrameworkInfo) -> QuantizeConfig:
35
35
  """
36
36
  Build a QuantizeConfig for a node according to its quantization configuration and
@@ -45,28 +45,29 @@ def quantization_config_builder_gptq(n: common.BaseNode,
45
45
  quantization configuration).
46
46
  """
47
47
 
48
- if n.is_weights_quantization_enabled() and n.is_activation_quantization_enabled():
48
+ if n.activation_quantization() and n.weight_quantization():
49
49
  qc = keras.quantizer.gradient_ptq.ActivationAndWeightQuantizeConfig(fw_info.get_kernel_op_attributes(n.layer_class),
50
50
  n.final_weights_quantization_cfg.weights_quantization_params.get(THRESHOLD),
51
51
  n.final_weights_quantization_cfg.weights_channels_axis,
52
52
  n.final_weights_quantization_cfg.weights_n_bits,
53
53
  n.activation_quantization_cfg.activation_quantization_params,
54
+ n.activation_quantization_cfg.activation_is_signed,
54
55
  activation_num_bits=n.activation_quantization_cfg.activation_n_bits,
55
56
  max_lsbs_change=MAX_LSBS_CHANGE
56
57
  )
57
-
58
- elif n.is_activation_quantization_enabled() and not n.is_weights_quantization_enabled():
58
+ elif n.activation_quantization():
59
59
  qc = keras.quantizer.gradient_ptq.ActivationQuantizeConfig(n.activation_quantization_cfg.activation_quantization_params,
60
+ n.activation_quantization_cfg.activation_is_signed,
60
61
  num_bits=n.activation_quantization_cfg.activation_n_bits)
61
-
62
- elif n.is_weights_quantization_enabled() and not n.is_activation_quantization_enabled():
62
+ elif n.weight_quantization():
63
63
  qc = keras.quantizer.gradient_ptq.WeightQuantizeConfig(fw_info.get_kernel_op_attributes(n.layer_class),
64
64
  n.final_weights_quantization_cfg.weights_quantization_params.get(THRESHOLD),
65
65
  n.final_weights_quantization_cfg.weights_channels_axis,
66
66
  n.final_weights_quantization_cfg.weights_n_bits,
67
- max_lsbs_change=MAX_LSBS_CHANGE)
67
+ max_lsbs_change=MAX_LSBS_CHANGE
68
+ )
68
69
 
69
- elif not n.is_weights_quantization_enabled() and not n.is_activation_quantization_enabled():
70
+ elif n.no_quantization():
70
71
  qc = NoOpQuantizeConfig()
71
72
 
72
73
  else:
@@ -16,15 +16,7 @@
16
16
  from typing import List, Tuple, Any, Dict
17
17
 
18
18
  from tensorflow import Tensor
19
- import tensorflow as tf
20
-
21
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
22
- if tf.__version__ < "2.6":
23
- from tensorflow.python.keras.layers import Layer
24
- else:
25
- from keras.engine.base_layer import Layer
26
-
27
-
19
+ from tensorflow.python.keras.layers import Layer
28
20
  from tensorflow.python.training.tracking.data_structures import ListWrapper
29
21
  from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
30
22
 
@@ -25,7 +25,7 @@ from model_compression_toolkit.keras.quantizer.mixed_precision.selective_weights
25
25
 
26
26
 
27
27
 
28
- def quantization_config_builder_mixed_precision(n: common.BaseNode,
28
+ def quantization_config_builder_mixed_precision(n: common.Node,
29
29
  fw_info: FrameworkInfo) -> QuantizeConfig:
30
30
  """
31
31
  Build a QuantizeConfig for layers that should be wrapped in a QuantizeWrapper to
@@ -17,12 +17,7 @@ from typing import List, Tuple, Any, Dict
17
17
 
18
18
  import numpy as np
19
19
  from tensorflow import Tensor
20
- import tensorflow as tf
21
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
22
- if tf.__version__ < "2.6":
23
- from tensorflow.python.keras.layers import Layer
24
- else:
25
- from keras.engine.base_layer import Layer
20
+ from tensorflow.python.keras.layers import Layer
26
21
  from tensorflow.python.training.tracking.data_structures import ListWrapper
27
22
  from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
28
23
 
@@ -20,19 +20,20 @@ import tensorflow as tf
20
20
  if tf.__version__ < "2.6":
21
21
  from tensorflow.python.keras.engine.node import Node as KerasNode
22
22
  from tensorflow.keras.layers import InputLayer
23
- from tensorflow.python.keras.engine.functional import Functional
24
- from tensorflow.python.keras.engine.sequential import Sequential
25
23
  else:
26
24
  from keras.engine.input_layer import InputLayer
27
25
  from keras.engine.node import Node as KerasNode
28
- from keras.engine.functional import Functional
29
- from keras.engine.sequential import Sequential
30
26
 
31
- from model_compression_toolkit.common.graph.base_node import BaseNode
32
27
 
28
+ from tensorflow.python.keras.engine.functional import Functional
33
29
 
30
+ from tensorflow.python.keras.engine.sequential import Sequential
34
31
 
35
- def is_node_an_input_layer(node: BaseNode) -> bool:
32
+ from model_compression_toolkit.common.graph.node import Node
33
+
34
+
35
+
36
+ def is_node_an_input_layer(node: Node) -> bool:
36
37
  """
37
38
  Checks if a node represents a Keras input layer.
38
39
  Args:
@@ -41,14 +42,14 @@ def is_node_an_input_layer(node: BaseNode) -> bool:
41
42
  Returns:
42
43
  Whether the node represents an input layer or not.
43
44
  """
44
- if isinstance(node, BaseNode):
45
+ if isinstance(node, Node):
45
46
  return node.layer_class == InputLayer
46
47
  elif isinstance(node, KerasNode):
47
48
  return isinstance(node.layer, InputLayer)
48
49
  else:
49
50
  raise Exception('Node to check has to be either a graph node or a keras node')
50
51
 
51
- def is_node_a_model(node: BaseNode) -> bool:
52
+ def is_node_a_model(node: Node) -> bool:
52
53
  """
53
54
  Checks if a node represents a Keras model.
54
55
  Args:
@@ -57,10 +58,11 @@ def is_node_a_model(node: BaseNode) -> bool:
57
58
  Returns:
58
59
  Whether the node represents a Keras model or not.
59
60
  """
60
- if isinstance(node, BaseNode):
61
+ if isinstance(node, Node):
61
62
  return node.layer_class in [Functional, Sequential]
62
63
  elif isinstance(node, KerasNode):
63
64
  return isinstance(node.layer, Functional) or isinstance(node.layer, Sequential)
64
65
  else:
65
66
  raise Exception('Node to check has to be either a graph node or a keras node')
66
67
 
68
+ # return node.layer_class in [Functional, Sequential]
@@ -15,18 +15,12 @@
15
15
 
16
16
 
17
17
  import tensorflow as tf
18
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
19
- if tf.__version__ < "2.6":
20
- from tensorflow.python.keras.engine.node import Node as KerasNode
21
- else:
22
- from keras.engine.node import Node as KerasNode
23
-
24
-
18
+ from tensorflow.python.keras.engine.node import Node as KerasNode
25
19
  from tensorflow.python.util.object_identity import Reference as TFReference
26
20
  from typing import List, Tuple
27
21
 
28
22
  from model_compression_toolkit.common.graph.base_graph import OutTensor
29
- from model_compression_toolkit.common.graph.base_node import BaseNode
23
+ from model_compression_toolkit.common.graph.node import Node
30
24
  from model_compression_toolkit.keras.reader.common import is_node_an_input_layer
31
25
  from model_compression_toolkit.keras.reader.node_builder import build_node
32
26
 
@@ -47,7 +41,7 @@ class ConnectivityHandler(object):
47
41
  self._nodes2output_tensors = dict() # Node -> List[Tensor]
48
42
  self._output_tensors2nodes = dict() # Tensor -> Node
49
43
 
50
- def get_nodes(self) -> List[BaseNode]:
44
+ def get_nodes(self) -> List[Node]:
51
45
  """
52
46
  Returns: List of nodes in the connectivity handler.
53
47
  """
@@ -66,7 +60,7 @@ class ConnectivityHandler(object):
66
60
  return self._input_tensors2nodes.get(tensor) is not None
67
61
 
68
62
  def input_tensor2nodes(self,
69
- in_tensor: str) -> List[BaseNode]:
63
+ in_tensor: str) -> List[Node]:
70
64
  """
71
65
  Returns a list of nodes that have a given tensor in their input tensors.
72
66
  Args:
@@ -90,7 +84,7 @@ class ConnectivityHandler(object):
90
84
  return self._output_tensors2nodes[out_tensor] if out_tensor in self._output_tensors2nodes else None
91
85
 
92
86
  def node2input_tensors(self,
93
- node: BaseNode) -> List[TFReference]:
87
+ node: Node) -> List[TFReference]:
94
88
  """
95
89
  Get a list of input tensors of a node.
96
90
  Args:
@@ -102,7 +96,7 @@ class ConnectivityHandler(object):
102
96
  return self._nodes2input_tensors[node] if node in self._nodes2input_tensors else []
103
97
 
104
98
  def node2output_tensors(self,
105
- node: BaseNode) -> List[TFReference]:
99
+ node: Node) -> List[TFReference]:
106
100
  """
107
101
  Get a list of output tensors of a node.
108
102
  Args:
@@ -179,8 +173,8 @@ class ConnectivityHandler(object):
179
173
  self._output_tensors2nodes[output_t] = node
180
174
 
181
175
  def get_edge_indices(self,
182
- src_node: BaseNode,
183
- dst_node: BaseNode,
176
+ src_node: Node,
177
+ dst_node: Node,
184
178
  connecting_tensor: TFReference) -> Tuple[int, int]:
185
179
  """
186
180
  Get indices of an edge by its source/destination nodes and the connecting tensor which defines the edge.
@@ -199,7 +193,7 @@ class ConnectivityHandler(object):
199
193
  return src_index, dst_index
200
194
 
201
195
  def get_out_edges_params_list(self,
202
- src_node: BaseNode) -> List[tuple]:
196
+ src_node: Node) -> List[tuple]:
203
197
  """
204
198
  Compute for a given node, all parameters of its outgoing edges.
205
199
  Args:
@@ -18,11 +18,11 @@ from typing import List, Dict, Tuple
18
18
 
19
19
  from model_compression_toolkit.common.graph.base_graph import Graph
20
20
  from model_compression_toolkit.common.graph.edge import Edge, convert_to_edge
21
- from model_compression_toolkit.common.graph.base_node import BaseNode
21
+ from model_compression_toolkit.common.graph.node import Node
22
22
  from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
23
23
 
24
24
 
25
- def merge_models_edges(inner_model_node: BaseNode,
25
+ def merge_models_edges(inner_model_node: Node,
26
26
  outer_graph: Graph,
27
27
  inner_graph: Graph) -> List[Edge]:
28
28
  """
@@ -91,8 +91,8 @@ def rewire_outgoing_edged(inner_model_outputs: List[OutTensor],
91
91
  res_edges.remove(model_node_out_edge)
92
92
 
93
93
 
94
- def rewire_incoming_edges(inner_inputs_out_edges: Dict[BaseNode, List[Edge]],
95
- inner_model_inputs_dict: List[BaseNode],
94
+ def rewire_incoming_edges(inner_inputs_out_edges: Dict[Node, List[Edge]],
95
+ inner_model_inputs_dict: List[Node],
96
96
  model_node_in_edges: List[Edge],
97
97
  res_edge_list: List[Edge]):
98
98
  """
@@ -126,7 +126,7 @@ def rewire_incoming_edges(inner_inputs_out_edges: Dict[BaseNode, List[Edge]],
126
126
  res_edge_list.remove(model_node_in_edge)
127
127
 
128
128
 
129
- def get_inner_inputs_successors(inner_graph: Graph) -> Dict[BaseNode, List[Edge]]:
129
+ def get_inner_inputs_successors(inner_graph: Graph) -> Dict[Node, List[Edge]]:
130
130
  """
131
131
  Compute out edges the input nodes of the inner model has.
132
132
  Args:
@@ -147,7 +147,7 @@ def get_inner_inputs_successors(inner_graph: Graph) -> Dict[BaseNode, List[Edge]
147
147
  return inner_inputs_out_edges
148
148
 
149
149
 
150
- def get_model_node_edges(model_node: BaseNode,
150
+ def get_model_node_edges(model_node: Node,
151
151
  outer_edge_list: List[Edge]) -> Tuple[List[Edge], List[Edge]]:
152
152
  """
153
153
  Get incoming and outgoing edges the inner model node has in the outer graph.
@@ -15,13 +15,13 @@
15
15
 
16
16
 
17
17
  from model_compression_toolkit.common.graph.base_graph import Graph
18
- from model_compression_toolkit.common.graph.base_node import BaseNode
18
+ from model_compression_toolkit.common.graph.node import Node
19
19
  from model_compression_toolkit.keras.reader.nested_model.edges_merger import merge_models_edges
20
20
  from model_compression_toolkit.keras.reader.nested_model.nodes_merger import merge_models_nodes
21
21
  from model_compression_toolkit.keras.reader.nested_model.outputs_merger import merge_models_outputs
22
22
 
23
23
 
24
- def merge_graphs(inner_model_node: BaseNode,
24
+ def merge_graphs(inner_model_node: Node,
25
25
  outer_graph: Graph,
26
26
  inner_graph: Graph) -> Graph:
27
27
  """
@@ -20,12 +20,12 @@ import copy
20
20
  from typing import List
21
21
 
22
22
  from model_compression_toolkit.common.graph.base_graph import Graph
23
- from model_compression_toolkit.common.graph.base_node import BaseNode
23
+ from model_compression_toolkit.common.graph.node import Node
24
24
 
25
25
 
26
- def merge_models_nodes(inner_model_node: BaseNode,
26
+ def merge_models_nodes(inner_model_node: Node,
27
27
  outer_graph: Graph,
28
- inner_graph: Graph) -> List[BaseNode]:
28
+ inner_graph: Graph) -> List[Node]:
29
29
  """
30
30
  Given two MultiDiGraphs (one of an outer model and the second of the inner model), merge their nodes into
31
31
  a single nodes list representing the nodes that should be in a single MultiDiGraph after unrolling the inner graph.
@@ -19,11 +19,11 @@ import copy
19
19
  from typing import List
20
20
 
21
21
  from model_compression_toolkit.common.graph.base_graph import Graph
22
- from model_compression_toolkit.common.graph.base_node import BaseNode
22
+ from model_compression_toolkit.common.graph.node import Node
23
23
  from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
24
24
 
25
25
 
26
- def merge_models_outputs(inner_model_node: BaseNode,
26
+ def merge_models_outputs(inner_model_node: Node,
27
27
  outer_graph: Graph,
28
28
  inner_graph: Graph) -> List[OutTensor]:
29
29
  """
@@ -12,21 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Any
16
15
 
17
- import tensorflow as tf
18
16
 
19
- if tf.__version__ < "2.6":
20
- from tensorflow.python.keras.layers.core import TFOpLambda, SlicingOpLambda
21
- from tensorflow.python.keras.engine.keras_tensor import KerasTensor
22
- from tensorflow.python.keras.engine.node import Node as KerasNode
23
- else:
24
- from keras.layers.core import TFOpLambda, SlicingOpLambda
25
- from keras.engine.keras_tensor import KerasTensor
26
- from keras.engine.node import Node as KerasNode
17
+ import tensorflow as tf
18
+ from tensorflow.python.keras.engine.node import Node as KerasNode
27
19
 
28
- from model_compression_toolkit.common.graph.base_node import BaseNode
29
- from model_compression_toolkit.common.graph.functional_node import FunctionalNode
20
+ from model_compression_toolkit.common.graph.node import Node
30
21
 
31
22
  keras = tf.keras
32
23
  layers = keras.layers
@@ -35,7 +26,7 @@ REUSED_IDENTIFIER = '_reused_'
35
26
 
36
27
 
37
28
  def build_node(node: KerasNode,
38
- node_name_to_node: dict) -> BaseNode:
29
+ node_name_to_node: dict) -> Node:
39
30
  """
40
31
  Build a node from a Keras node. A node contains all information to reconstruct the layer it's representing
41
32
  in a model:
@@ -51,8 +42,7 @@ def build_node(node: KerasNode,
51
42
  """
52
43
  keras_layer = node.layer # get the layer the node represents.
53
44
  layer_config = keras_layer.get_config() # layer configuration to reconstruct it.
54
- op_call_args = node.call_args
55
- op_call_kwargs = node.call_kwargs
45
+ op_call_args = node.call_kwargs
56
46
  layer_class = type(keras_layer) # class path to instantiating it in back2framework.
57
47
  weights = {v.name: v.numpy() for v in keras_layer.weights} # layer's weights
58
48
 
@@ -76,56 +66,16 @@ def build_node(node: KerasNode,
76
66
  input_shape = keras_layer.get_input_shape_at(io_index)
77
67
  output_shape = keras_layer.get_output_shape_at(io_index)
78
68
 
79
- if layer_class in [TFOpLambda, SlicingOpLambda]:
80
- # Some functional ops (such as tf.concat) should receive the input tensors as a list
81
- # and some are not (such as tf.multiply), so each FunctionalNode holds
82
- # a flag to indicate that.
83
- inputs_as_list = __is_functional_inputs_a_list(op_call_args)
84
- # Do not hold the tensors that are in op_call_args as they are
85
- # not needed. Thus, if the first element in op_call_args is a list of
86
- # Keras tensors, remove it from op_call_args.
87
- op_call_args = op_call_args[int(inputs_as_list):]
88
- node = FunctionalNode(node_name,
89
- layer_config,
90
- input_shape,
91
- output_shape,
92
- weights,
93
- layer_class,
94
- [arg for arg in op_call_args if not isinstance(arg, KerasTensor)], # Do not hold the tensors that are in op_call_args
95
- {k: v for k, v in op_call_kwargs.items() if not isinstance(v, KerasTensor)}, # In TF2.5 tensors are in kwargs as well.
96
- is_reused,
97
- reuse_group,
98
- functional_op=keras_layer.function,
99
- inputs_as_list=inputs_as_list)
100
- else:
101
- node = BaseNode(node_name,
102
- layer_config,
103
- input_shape,
104
- output_shape,
105
- weights,
106
- layer_class,
107
- is_reused,
108
- reuse_group)
69
+ node = Node(node_name,
70
+ layer_config,
71
+ input_shape,
72
+ output_shape,
73
+ weights,
74
+ layer_class,
75
+ is_reused,
76
+ reuse_group,
77
+ op_call_args)
109
78
 
110
79
  node_name_to_node[node_name] = node
111
- return node
112
-
113
-
114
- def __is_functional_inputs_a_list(op_call_args: Any) -> bool:
115
- """
116
- Check whether the input tensors should be passed as a list
117
- or not.
118
80
 
119
- Args:
120
- op_call_args: Arguments list to check.
121
-
122
- Returns:
123
- Whether the input tensors should be passed as a list or not.
124
- """
125
-
126
- if len(op_call_args) > 0 and isinstance(op_call_args[0], list):
127
- inputs_as_list = True
128
- for arg in op_call_args[0]:
129
- inputs_as_list = inputs_as_list and isinstance(arg, KerasTensor)
130
- return inputs_as_list
131
- return False
81
+ return node
@@ -24,8 +24,8 @@ from typing import List
24
24
 
25
25
  from model_compression_toolkit.common.graph.base_graph import Graph
26
26
  from model_compression_toolkit.common.graph.edge import Edge
27
- from model_compression_toolkit.common.graph.base_node import BaseNode
28
- from model_compression_toolkit.keras.reader.common import is_node_a_model, is_node_an_input_layer
27
+ from model_compression_toolkit.common.graph.node import Node
28
+ from model_compression_toolkit.keras.reader.common import is_node_a_model
29
29
  from model_compression_toolkit.keras.reader.connectivity_handler import ConnectivityHandler
30
30
  from model_compression_toolkit.keras.reader.nested_model.nested_model_handler import merge_graphs
31
31
 
@@ -70,6 +70,7 @@ def build_tensors_list(tensors_list) -> List[TFReference]:
70
70
  tensor in
71
71
  tensors_list]
72
72
 
73
+
73
74
  def build_connectivity_handler(model: Model) -> ConnectivityHandler:
74
75
  """
75
76
  Build a connectivity handler containing all information about connections in the model (nodes and
@@ -84,8 +85,7 @@ def build_connectivity_handler(model: Model) -> ConnectivityHandler:
84
85
  connectivity_handler = ConnectivityHandler()
85
86
  for nodes in model._nodes_by_depth.values():
86
87
  for node in nodes: # nodes_by depth values are lists (each list for a different depth)
87
- node_inputs = node.input_tensors if is_node_an_input_layer(node) else node.keras_inputs
88
- input_tensors = build_tensors_list(node_inputs) # build input tensors of the node
88
+ input_tensors = build_tensors_list(node.input_tensors) # build input tensors of the node
89
89
  output_tensors = build_tensors_list(node.output_tensors) # build output tensors of the node
90
90
  connectivity_handler.add_node(node,
91
91
  input_tensors,
@@ -143,7 +143,7 @@ def parse_model(model: Model) -> Graph:
143
143
 
144
144
 
145
145
  def flatten_nested_model(outer_graph: Graph,
146
- inner_model_node: BaseNode,
146
+ inner_model_node: Node,
147
147
  outer_keras_model: Model):
148
148
  """
149
149
  Flat a nested model given two graphs: inner and outer models' graphs.
@@ -0,0 +1,113 @@
1
+ # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from tensorflow.keras.layers import ReLU, Activation
18
+
19
+ from model_compression_toolkit import common
20
+ from model_compression_toolkit.common import FrameworkInfo
21
+ from model_compression_toolkit.common.graph.node import Node
22
+ from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
23
+ from model_compression_toolkit.keras.constants import LINEAR, ACTIVATION, RELU_MAX_VALUE, THRESHOLD, NEGATIVE_SLOPE
24
+
25
+
26
+ def get_stats_collector_for_activation_op(n: Node,
27
+ fw_info: FrameworkInfo) -> common.StatsContainer:
28
+ """
29
+ Create and initial a statistics collector for an activation layer. If the activation function's min/max
30
+ output values are known, the statistics collector is initialized with these values.
31
+
32
+ Args:
33
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
34
+ groups of layers by how they should be quantized, etc.)
35
+ n: Node to create a statistics collector for it.
36
+
37
+ Returns:
38
+ Statistics collector initialized with known min/max values.
39
+ """
40
+
41
+ if n.layer_class == ReLU:
42
+ negative_slope = n.framework_attr[NEGATIVE_SLOPE]
43
+ threshold = n.framework_attr[THRESHOLD]
44
+ init_max = n.framework_attr[RELU_MAX_VALUE]
45
+ return common.StatsContainer(init_min_value=threshold if negative_slope == 0 else None,
46
+ init_max_value=init_max)
47
+
48
+ if n.layer_class == Activation:
49
+ init_min, init_max = fw_info.activation_min_max_mapping[n.framework_attr[ACTIVATION]]
50
+ return common.StatsContainer(init_min_value=init_min,
51
+ init_max_value=init_max)
52
+
53
+ if fw_info.layers_has_min_max(n.layer_class):
54
+ init_min, init_max = fw_info.layer_min_max_mapping[n.layer_class]
55
+ return common.StatsContainer(init_min_value=init_min,
56
+ init_max_value=init_max)
57
+
58
+ return common.StatsContainer()
59
+
60
+
61
+ def get_stats_collector_for_kernel_op(n: common.Node,
62
+ fw_info: FrameworkInfo) -> BaseStatsContainer:
63
+ """
64
+ Create and initial a statistics collector for a linear operator. If the layer has an activation function and
65
+ its min/max output values are known, the statistics collector is initialized with these values.
66
+ If the layer's output should not be quantized, NoStatsContainer is created.
67
+
68
+ Args:
69
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
70
+ groups of layers by how they should be quantized, etc.)
71
+ n: Node to create a statistics collector for it.
72
+
73
+ Returns:
74
+ BaseStatsContainer according to statistics that are collected.
75
+ """
76
+
77
+ if n.framework_attr[ACTIVATION] == LINEAR and n.output_quantization:
78
+ return common.StatsContainer()
79
+
80
+ if n.framework_attr[ACTIVATION] == LINEAR and not n.output_quantization:
81
+ return common.NoStatsContainer()
82
+
83
+ if n.framework_attr[ACTIVATION] in fw_info.activation_min_max_mapping.keys():
84
+ min_value, max_value = fw_info.activation_min_max_mapping[n.framework_attr[ACTIVATION]]
85
+ return common.StatsContainer(init_min_value=min_value,
86
+ init_max_value=max_value)
87
+
88
+ return common.StatsContainer()
89
+
90
+
91
+ def get_node_stats_collector(node: common.Node,
92
+ fw_info: common.FrameworkInfo) -> common.statistics_collector.BaseStatsContainer:
93
+ """
94
+ Gets a node and a groups list and create and return a statistics collector for the node
95
+ according to the group the node is in.
96
+
97
+ Args:
98
+ node: Node to create its statistics collector.
99
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
100
+ groups of layers by how they should be quantized, etc.)
101
+
102
+ Returns:
103
+ Statistics collector for statistics collection for the node.
104
+ """
105
+
106
+ stats_collector = get_stats_collector_for_activation_op(node, fw_info)
107
+ if fw_info.in_no_quantization_ops(node): # node should not be quantized
108
+ stats_collector = common.NoStatsContainer()
109
+
110
+ if fw_info.in_kernel_ops(node): # node's kernel should be quantized
111
+ stats_collector = get_stats_collector_for_kernel_op(node, fw_info)
112
+
113
+ return stats_collector
@@ -22,10 +22,10 @@ from model_compression_toolkit.common import Graph
22
22
  from model_compression_toolkit.common.similarity_analyzer import compute_cs
23
23
  from model_compression_toolkit.keras.back2framework.model_builder import model_builder
24
24
  from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
25
- from model_compression_toolkit.common.graph.base_node import BaseNode
25
+ from model_compression_toolkit.common.graph.node import Node
26
26
 
27
27
 
28
- def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str]]:
28
+ def get_compare_points(input_graph: Graph) -> Tuple[List[Node], List[str]]:
29
29
  """
30
30
  Create a list of nodes in a graph where we collect their output statistics, and a corresponding list
31
31
  of their names for tensors comparison purposes.
@@ -1,43 +0,0 @@
1
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from model_compression_toolkit import common
18
- from model_compression_toolkit.common.collectors.statistics_collector import BaseStatsCollector
19
-
20
-
21
- def create_stats_collector_for_node(node: common.BaseNode,
22
- output_channel_index: int) -> BaseStatsCollector:
23
- """
24
- Gets a node and a groups list and create and return a statistics collector for a node
25
- according to whether its statistics should be collected and the prior information we
26
- have about this node.
27
-
28
- Args:
29
- node: Node to create its statistics collector.
30
- output_channel_index: Index of output channels (for statistics per-channel).
31
-
32
- Returns:
33
- Statistics collector for statistics collection for the node.
34
- """
35
-
36
- if node.is_activation_quantization_enabled():
37
- stats_collector = common.StatsCollector(init_min_value=node.prior_info.min_output,
38
- init_max_value=node.prior_info.max_output,
39
- output_channel_index=output_channel_index)
40
- else:
41
- stats_collector = common.NoStatsCollector()
42
-
43
- return stats_collector