mct-nightly 2.0.0.20240417.406__py3-none-any.whl → 2.0.0.20240419.358__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/RECORD +60 -57
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +2 -0
  5. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  6. model_compression_toolkit/core/common/graph/base_node.py +26 -9
  7. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  8. model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
  9. model_compression_toolkit/core/common/hessian/trace_hessian_request.py +1 -3
  10. model_compression_toolkit/core/common/network_editors/node_filters.py +4 -3
  11. model_compression_toolkit/core/common/quantization/node_quantization_config.py +0 -5
  12. model_compression_toolkit/core/common/quantization/quantization_config.py +5 -2
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +67 -4
  14. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +12 -4
  15. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +14 -4
  16. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +30 -3
  17. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +17 -7
  18. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +14 -3
  19. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +13 -3
  20. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +16 -3
  21. model_compression_toolkit/core/common/similarity_analyzer.py +16 -4
  22. model_compression_toolkit/core/common/substitutions/remove_identity.py +48 -0
  23. model_compression_toolkit/core/graph_prep_runner.py +10 -4
  24. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +4 -1
  25. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -7
  26. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  27. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py +51 -0
  28. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  29. model_compression_toolkit/core/keras/keras_implementation.py +13 -11
  30. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -4
  31. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +4 -5
  32. model_compression_toolkit/core/keras/reader/common.py +2 -2
  33. model_compression_toolkit/core/keras/reader/node_builder.py +28 -9
  34. model_compression_toolkit/core/keras/tf_tensor_numpy.py +5 -2
  35. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +34 -21
  36. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py +8 -8
  37. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
  38. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  39. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +4 -4
  40. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py +50 -0
  41. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  42. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +8 -8
  43. model_compression_toolkit/core/pytorch/pytorch_implementation.py +7 -6
  44. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +2 -2
  45. model_compression_toolkit/core/quantization_prep_runner.py +6 -2
  46. model_compression_toolkit/core/runner.py +5 -2
  47. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -1
  48. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +9 -2
  49. model_compression_toolkit/gptq/keras/quantization_facade.py +2 -1
  50. model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -1
  51. model_compression_toolkit/gptq/runner.py +1 -0
  52. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +5 -5
  53. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +1 -1
  54. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +20 -6
  55. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +1 -1
  56. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +22 -8
  57. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +1 -1
  58. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/LICENSE.md +0 -0
  59. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/WHEEL +0 -0
  60. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/keras/back2framework/keras_model_builder.py
@@ -39,6 +39,7 @@ from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
 from model_compression_toolkit.core.keras.reader.connectivity_handler import OutTensor
+from mct_quantizers import KerasQuantizationWrapper

 # In tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda.
 FQ_NODE_OP_V2_3 = 'FakeQuantWithMinMaxVars'
@@ -270,7 +271,9 @@ class KerasModelBuilder(BaseModelBuilder):
                                    out_tensors_of_n_float)
             else:
                 input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list]  # flat list of lists
-                input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
+                if not isinstance(op_func, KerasQuantizationWrapper):
+                    # The KerasQuantizationWrapper will insert the quantized positional weights internally.
+                    input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
                 # Build a functional node using its args
                 if isinstance(n, FunctionalNode):
                     if n.inputs_as_list:  # If the first argument should be a list of tensors:
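Note: the new guard works because the wrapper re-inserts its stored constants at call time. A minimal sketch of that behavior, assuming a wrapper that keeps positional weights keyed by call-argument index (illustrative only, not the mct_quantizers implementation):

# Sketch only: why the builder must not insert positional weights twice for
# wrapped ops. The wrapper holds the (quantized) constants itself and splices
# them back into the positional arguments when it is called.
class QuantizationWrapperSketch:
    def __init__(self, layer, positional_weights):
        self.layer = layer
        self.positional_weights = dict(positional_weights)  # {arg_index: constant}

    def __call__(self, *inputs):
        args = list(inputs)
        for idx, const in sorted(self.positional_weights.items()):
            args.insert(idx, const)  # the wrapper inserts the constants itself
        return self.layer(*args)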
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py
@@ -70,9 +70,9 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode,
     Returns:
         The modified convolution node's weight/kernel/
     """
-    if conv_node.type == DepthwiseConv2D:
+    if conv_node.is_match_type(DepthwiseConv2D):
         kernel = kernel * weights_scale.reshape((1, 1, kernel.shape[-2], kernel.shape[-1]))
-    elif conv_node.type == Conv2DTranspose:
+    elif conv_node.is_match_type(Conv2DTranspose):
         kernel = kernel * weights_scale.reshape((1, 1, -1, 1))
     else:
         kernel = kernel * weights_scale.reshape((1, 1, 1, -1))
@@ -98,10 +98,10 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
     Returns:
         The modified convolution node's weight/kernel/
     """
-    if conv_node.type == DepthwiseConv2D:
+    if conv_node.is_match_type(DepthwiseConv2D):
         bias_update = kernel * bias_factor.reshape((1, 1, -1, 1))
         kernel = kernel * weights_scale.reshape((1, 1, -1, 1))
-    elif conv_node.type == Conv2DTranspose:
+    elif conv_node.is_match_type(Conv2DTranspose):
         bias_update = (kernel * bias_factor.reshape((1, 1, 1, -1))).sum(3)
         kernel = kernel * weights_scale.reshape((1, 1, 1, -1))
     else:
@@ -133,7 +133,7 @@ def is_group_conv_fn(node: BaseNode) -> bool:
     Returns:
         True if the node is a group convolution, else False
     """
-    return (node.type == Conv2D) and node.framework_attr[GROUPS] > 1
+    return (node.is_match_type(Conv2D)) and node.framework_attr[GROUPS] > 1


 def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
@@ -147,8 +147,8 @@ def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
         is_bn: True if the node is a batch norm, else False
         is_dw_valid: True if the node is a dw-convolution valid for folding or a batch-norm node, else False
     """
-    is_bn = node.type is BatchNormalization
-    is_dw = node.type is DepthwiseConv2D
+    is_bn = node.is_match_type(BatchNormalization)
+    is_dw = node.is_match_type(DepthwiseConv2D)
     is_dw_valid = is_dw and np.all(np.array(node.get_weights_by_keys(DEPTHWISE_KERNEL).shape[:2]) == 1)
     return is_bn, is_dw_valid

model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py
@@ -58,7 +58,7 @@ def conv2d_collapsing_fn(first_node: BaseNode,
     Returns:
         The modified layer node's weights: kernel, bias
     """
-    if first_node.type == Conv2D and second_node.type == Conv2D:
+    if first_node.is_match_type(Conv2D) and second_node.is_match_type(Conv2D):
         # Get nodes attributes
         kernel1 = first_node.get_weights_by_keys(kernel_str)
         kernel2 = second_node.get_weights_by_keys(kernel_str)
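Note: the recurring change in this release is replacing direct `node.type == X` comparisons with `node.is_match_type(X)`. The method itself is added in base_node.py and functional_node.py (items 6 and 7 above), whose hunks are not shown here; presumably it lets a functional node match against the op it wraps in addition to its layer class. A rough sketch of the idea, where every name other than is_match_type is an assumption:

# Sketch of the presumed matching logic; not the MCT implementation.
class BaseNodeSketch:
    def __init__(self, layer_class):
        self.layer_class = layer_class

    def is_match_type(self, _type) -> bool:
        # Base behavior: equivalent to the old `node.type == _type` check.
        return _type == self.layer_class


class FunctionalNodeSketch(BaseNodeSketch):
    def __init__(self, layer_class, functional_op):
        super().__init__(layer_class)
        self.functional_op = functional_op  # e.g. tf.add wrapped by a TFOpLambda

    def is_match_type(self, _type) -> bool:
        # A functional node also matches the op it wraps, which the plain
        # `node.type == _type` comparison would miss.
        return super().is_match_type(_type) or _type == self.functional_op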
model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py
@@ -0,0 +1,51 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import keras
+import tensorflow as tf
+
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.common.graph.base_graph import Graph
+from model_compression_toolkit.core.common.graph.base_node import BaseNode
+from model_compression_toolkit.core.common.substitutions.remove_identity import remove_identity_node
+
+
+class RemoveIdentity(common.BaseSubstitution):
+    """
+    Remove Identity layers from the graph.
+    """
+
+    def __init__(self):
+        nodes = NodeOperationMatcher(keras.layers.Identity) | NodeOperationMatcher(tf.identity)
+        super().__init__(matcher_instance=nodes)
+
+    def substitute(self,
+                   graph: Graph,
+                   node: BaseNode) -> Graph:
+        """
+        The method to perform the substitution of the identity keras node by
+        reconnecting its input directly to its output, effectively removing the node
+        from the graph.
+
+        Args:
+            graph: The current graph of operations where the node resides.
+            node: The specific `BaseNode` that is matched to be an Identity operation.
+
+        Returns:
+            Graph: The updated graph after removing the identity node.
+        """
+        return remove_identity_node(graph, node)
+
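Note: the shared helper this class delegates to is added in model_compression_toolkit/core/common/substitutions/remove_identity.py (item 22 above), whose hunk is not shown here. Conceptually it rewires the graph around the matched node, roughly like this sketch (the graph method names are assumptions, not the MCT API):

def remove_identity_node_sketch(graph, node):
    # Sketch only: reconnect the identity node's single input to all of its
    # consumers, then drop the node itself from the graph.
    prev_node = graph.get_prev_nodes(node)[0]  # an identity op has one input
    graph.reconnect_out_edges(current_node=node, new_node=prev_node)
    graph.remove_edge(prev_node, node)
    graph.remove_node(node)
    return graph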
model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py
@@ -49,7 +49,7 @@ def residual_collapsing_fn(first_node: BaseNode,
     Returns:
         The modified layer node's weights: kernel
     """
-    if first_node.type == Conv2D:
+    if first_node.is_match_type(Conv2D):
         # Get nodes attributes
         kernel = first_node.get_weights_by_keys(kernel_str)
         (kH, kW, Cin, Cout) = kernel.shape
model_compression_toolkit/core/keras/keras_implementation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 from functools import partial
-from typing import List, Any, Tuple, Callable, Dict
+from typing import List, Any, Tuple, Callable, Dict, Union

 import numpy as np
 import tensorflow as tf
@@ -22,6 +22,7 @@ from tensorflow.keras.models import Model

 from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
 from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoService
+from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
 from model_compression_toolkit.core.keras.hessian.activation_trace_hessian_calculator_keras import \
     ActivationTraceHessianCalculatorKeras
 from model_compression_toolkit.core.keras.hessian.weights_trace_hessian_calculator_keras import WeightsTraceHessianCalculatorKeras
@@ -246,7 +247,8 @@ class KerasImplementation(FrameworkImplementation):
                 MatmulToDenseSubstitution(),
                 MultiHeadAttentionDecomposition(),
                 ActivationDecomposition(),
-                DwconvToConv()]
+                DwconvToConv(),
+                RemoveIdentity()]

     def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
             List[common.BaseSubstitution]:
@@ -410,12 +412,13 @@ class KerasImplementation(FrameworkImplementation):
         Returns: True if the node should be considered an interest point, False otherwise.
         """

-        if node.type == Activation:
+        if node.is_match_type(Activation):
             node_type_name = node.framework_attr[keras_constants.ACTIVATION]
             if node_type_name in [keras_constants.SOFTMAX, keras_constants.SIGMOID]:
                 return True
-        elif node.type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate,
-                           tf.concat, Add, tf.add]:
+        elif any([node.is_match_type(_type) for _type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D,
+                                                          DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate, tf.concat,
+                                                          Add, tf.add]]):
             return True

         return False
@@ -527,18 +530,18 @@
         kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
         output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)

-        if node.type is Conv2D or node.type is Conv2DTranspose:
+        if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose):
             # (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
             return np.prod([x for x in output_shape if x is not None]) * \
                    kernel_shape[input_channel_axis] * \
                    (kernel_shape[0] * kernel_shape[1])
-        elif node.type is DepthwiseConv2D:
+        elif node.is_match_type(DepthwiseConv2D):
             # Depth * (W_out * H_out) * C_in * (W_kernel * H_kernel)
             return node.framework_attr.get(DEPTH_MULTIPLIER) * \
                    np.prod([x for x in output_shape if x is not None]) / output_shape[output_channel_axis] * \
                    kernel_shape[input_channel_axis] * \
                    (kernel_shape[0] * kernel_shape[1])
-        elif node.type is Dense:
+        elif node.is_match_type(Dense):
             # IN * OUT
             return kernel_shape[0] * kernel_shape[1]
         else:
@@ -591,10 +594,9 @@
         Returns:
             weight_quantizers: A dictionary between a weight's name to its quantizer.
             activation_quantizers: A list of activations quantization, one for each layer output.
-
         """

-        def _weight_name(w: str) -> str:
+        def _weight_name(w: Union[str, int]) -> Union[str, int]:
             """
             Extracts the weight name from the full TensorFlow variable name.

@@ -607,7 +609,7 @@
             Extracted weight name.
             """

-            return w.split(':')[0].split('/')[-1]
+            return w.split(':')[0].split('/')[-1] if isinstance(w, str) else w

         attribute_names = [_weight_name(wn) for wn in node.get_node_weights_attributes()
                            if node.is_weights_quantization_enabled(wn)]
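Note: the widened Union[str, int] signature reflects that positional weights are keyed by their call-argument index rather than by a TF variable name. Illustratively (not part of the diff):

_weight_name('conv2d_1/kernel:0')  # -> 'kernel' (named TF variable)
_weight_name(1)                    # -> 1 (positional weight keyed by argument index)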
model_compression_toolkit/core/keras/keras_node_prior_info.py
@@ -56,13 +56,13 @@ def _get_min_max_outputs(node: BaseNode,
     """
     min_output, max_output = None, None

-    if node.type == ReLU:
+    if node.is_match_type(ReLU):
         min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None

     elif fw_info.layers_has_min_max(node.type):
         min_output, max_output = fw_info.layer_min_max_mapping[node.type]

-    elif node.type == Activation and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
+    elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
         min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]

     return min_output, max_output
@@ -82,7 +82,7 @@ def _get_mean_std_outputs(node: BaseNode,
     """
     mean_output, std_output = None, None

-    if node.type == BatchNormalization:
+    if node.is_match_type(BatchNormalization):
         mean_output = node.get_weights_by_keys(BETA)
         if node.get_weights_by_keys(GAMMA) is None:
             std_output = 1.0
@@ -92,7 +92,7 @@ def _get_mean_std_outputs(node: BaseNode,
             mean_output = 0.0
     else:
         next_node_list = graph.get_next_nodes(node)
-        bn_nodes = [bn_node for bn_node in next_node_list if bn_node.type == BatchNormalization]
+        bn_nodes = [bn_node for bn_node in next_node_list if bn_node.is_match_type(BatchNormalization)]
         if len(bn_nodes) != 0:
             bn_node = bn_nodes[0]
             moving_variance = bn_node.get_weights_by_keys(MOVING_VARIANCE)
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
@@ -209,10 +209,9 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
     """

     # Check if the node is a Conv2D or Conv2DTranspose layer with groups set to 1.
-    if node.type in [keras.layers.Conv2D, keras.layers.Conv2DTranspose]:
+    if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose):
         return node.framework_attr[GROUPS] == 1
-    return node.type == keras.layers.Dense
-
+    return node.is_match_type(keras.layers.Dense)


 def _prune_keras_edge_node(node: BaseNode,
@@ -250,9 +249,9 @@ def _prune_keras_edge_node(node: BaseNode,

     if not is_exit_node:
         # Update 'filters' or 'units' attributes for entry node Conv2D/Conv2DTranspose layers.
-        if node.type in [keras.layers.Conv2D, keras.layers.Conv2DTranspose]:
+        if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose):
             node.framework_attr[FILTERS] = int(np.sum(mask))
-        elif node.type == keras.layers.Dense:
+        elif node.is_match_type(keras.layers.Dense):
             node.framework_attr[UNITS] = int(np.sum(mask))

     if is_exit_node:
model_compression_toolkit/core/keras/reader/common.py
@@ -43,7 +43,7 @@ def is_node_an_input_layer(node: BaseNode) -> bool:
         Whether the node represents an input layer or not.
     """
     if isinstance(node, BaseNode):
-        return node.type == InputLayer
+        return node.is_match_type(InputLayer)
     elif isinstance(node, KerasNode):
         return isinstance(node.layer, InputLayer)
     else:
@@ -60,7 +60,7 @@ def is_node_a_model(node: BaseNode) -> bool:
         Whether the node represents a Keras model or not.
     """
     if isinstance(node, BaseNode):
-        return node.type in [Functional, Sequential]
+        return node.is_match_type(Functional) or node.is_match_type(Sequential)
     elif isinstance(node, KerasNode):
         return isinstance(node.layer, Functional) or isinstance(node.layer, Sequential)
     else:
model_compression_toolkit/core/keras/reader/node_builder.py
@@ -41,7 +41,7 @@ layers = keras.layers

 REUSED_IDENTIFIER = '_reused_'

-is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray))
+is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, float))
 is_tensor = lambda x: isinstance(x, KerasTensor)

@@ -61,18 +61,36 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
     """
     Positional weights are saved according to their index in the node's call arguments, so
     need to know the function arguments' names in case the weights are in the kwargs.
+
+    Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so
+    it will only include the arguments that may contain constants. For example, we don't
+    want the transpose_a attribute of tf.matmul to be saved as a constant.
+
+    Every operation we add support to, needs to be added here.
+
     Args:
         tfoplambda_layer: TFOpLambda layer.

     Returns:
         A dictionary with argument number and index: {arg_name: arg_index}.
     """
-    if tfoplambda_layer.function in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
-                                     tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression] or \
-            tfoplambda_layer.symbol in ['__operators__.add', 'math.add', 'math.multiply', 'linalg.matmul', 'concat']:
-        return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tfoplambda_layer.function).args)}
-    else:
-        return {}
+    kwargs2index = {tf.add: {'x': 0, 'y': 1},
+                    tf.subtract: {'x': 0, 'y': 1},
+                    tf.divide: {'x': 0, 'y': 1},
+                    tf.truediv: {'x': 0, 'y': 1},
+                    tf.multiply: {'x': 0, 'y': 1},
+                    tf.pow: {'x': 0, 'y': 1},
+                    tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function)
+    if not kwargs2index:
+        # In TF 2.15 the function attribute is different and doesn't match the original
+        # operation object we use. Therefore, we extract kwargs2index with the symbol.
+        kwargs2index = {'__operators__.add': {'x': 0, 'y': 1},
+                        'math.add': {'x': 0, 'y': 1},
+                        'math.multiply': {'x': 0, 'y': 1},
+                        'linalg.matmul': {'a': 0, 'b': 1},
+                        'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {})
+
+    return kwargs2index


 def build_node(node: KerasNode,
@@ -154,8 +172,9 @@ def build_node(node: KerasNode,
         if is_const(v) or (keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow,
                                                     tf.matmul] and
                            isinstance(v, (tuple, list))):
-            weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
-            weight_keys.append(k)
+            if k in kwarg2index:
+                weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
+                weight_keys.append(k)
         # remove weights and KerasTensors and weights from op_call_kwargs
         op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
                           if not (kwarg2index.get(k) in weights or is_tensor(v))}
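Note: the two-level dict.get lookup is what makes unsupported ops degrade gracefully, since their constants are simply not stored. The same pattern in isolation (illustrative):

# dict.get returns None for an unknown function, which triggers the
# symbol-based fallback, which in turn defaults to {} so that unsupported
# ops store no positional constants at all.
table = {'linalg.matmul': {'a': 0, 'b': 1}, 'concat': {'values': 0}}
print(table.get('concat'))         # {'values': 0}
print(table.get('unknown', {}))    # {}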
model_compression_toolkit/core/keras/tf_tensor_numpy.py
@@ -40,7 +40,7 @@ def to_tf_tensor(tensor):
     Logger.critical(f'Unsupported type for conversion to TF tensor: {type(tensor)}.')


-def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor],
+def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor, float],
                        is_single_tensor=False) -> np.ndarray:
     """
     Convert a TF tensor to a Numpy array.
@@ -65,6 +65,9 @@ tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor],
         else:
             return (tf_tensor_to_numpy(t) for t in tensor)
     elif isinstance(tensor, tf.Tensor):
-        return tensor.numpy()
+        np_tensor = tensor.numpy()
+        return np.array([np_tensor]) if np.isscalar(np_tensor) else np_tensor
+    elif isinstance(tensor, float):
+        return np.array([tensor])
     else:
         Logger.critical(f'Unsupported type for conversion to Numpy array: {type(tensor)}.')
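Note: together with the float case added to is_const above, this means Python floats and scalar tensors now convert to 1-element arrays instead of failing or producing 0-d arrays. Illustratively (assumes eager TF):

tf_tensor_to_numpy(2.0)               # -> np.array([2.0]); floats were previously rejected
tf_tensor_to_numpy(tf.constant(3.0))  # -> np.array([3.0]) instead of a 0-d array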
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
@@ -33,26 +33,31 @@ from model_compression_toolkit.core.pytorch.pytorch_device_config import get_wor
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
+from mct_quantizers import PytorchQuantizationWrapper


 def _build_input_tensors_list(node: BaseNode,
                               graph: Graph,
                               inputs: Tuple[Any],
-                              node_to_output_tensors_dict: Dict[BaseNode, List]) -> List[List]:
+                              node_to_output_tensors_dict: Dict[BaseNode, List],
+                              is_op_quantize_wrapper: bool) -> List[List]:
     """
-    Given a node, build a list of input tensors the node gets. The list is built
-    based on the node's incoming edges and previous nodes' output tensors.
+    Given a node, build a list of input tensors the node gets. The list is built based on the
+    node's incoming edges, previous nodes' output tensors and the node's positional weights.
+    Positional weights aren't used if the node's op is PytorchQuantizationWrapper, since it's
+    positional weights are already in the wrapper.

     Args:
         node: Node to build its input tensors list.
         graph: Graph the node is in.
-        inputs: list of input tensors to model
+        inputs: list of input tensors to model.
         node_to_output_tensors_dict: A dictionary from a node to its output tensors.
+        is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.

     Returns:
         A list of the node's input tensors.
     """
-    if node.type == DummyPlaceHolder:
+    if node.is_match_type(DummyPlaceHolder):
         input_tensors = [inputs[graph.get_inputs().index(node)]]
     else:
         input_tensors = []
@@ -62,7 +67,8 @@ def _build_input_tensors_list(node: BaseNode,
             _input_tensors = node_to_output_tensors_dict[ie.source_node]
             input_tensors.append(_input_tensors)
         input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list]  # flat list of lists
-        input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
+        if not is_op_quantize_wrapper:
+            input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
         # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
         # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
         input_tensors = [to_torch_tensor(t) if isinstance(t, np.ndarray) else t
@@ -70,22 +76,27 @@ def _build_input_tensors_list(node: BaseNode,
     return input_tensors


-def _merge_inputs(_node, input_tensors: List, op_call_args: List) -> List:
+def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
+                  is_op_quantize_wrapper: bool) -> List:
     """
-    Merge input tensors list with op_call_args, according to correct order
+    Merge input tensors list with op_call_args, according to correct order.

     Args:
-        _node: The node the inputs are for
+        _node: The node the inputs are for.
         input_tensors: activation input tensors to node.
-        op_call_args: framework node call args
+        op_call_args: framework node call args.
+        is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.
     Returns:
-        Combined list of input_tensors and op_call_args
+        Combined list of input_tensors and op_call_args.
     """
     if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
-        assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices'
         _input_list = op_call_args.copy()
-        for i, t in zip(_node.tensor_input_indices, input_tensors):
-            _input_list.insert(i, t)
+        if is_op_quantize_wrapper:
+            _input_list = input_tensors + _input_list
+        else:
+            assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices'
+            for i, t in zip(_node.tensor_input_indices, input_tensors):
+                _input_list.insert(i, t)
     else:
         _input_list = input_tensors + op_call_args

@@ -118,7 +129,8 @@ def _run_operation(n: BaseNode,
     if isinstance(n, FunctionalNode) and n.inputs_as_list:
         out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
     else:
-        out_tensors_of_n_float = op_func(*_merge_inputs(n, input_tensors, op_call_args), **functional_kwargs)
+        merged_inputs = _merge_inputs(n, input_tensors, op_call_args, isinstance(op_func, PytorchQuantizationWrapper))
+        out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)

     # Add a fake quant node if the node has an activation threshold.
     out_tensors_of_n = out_tensors_of_n_float
@@ -279,12 +291,12 @@ class PytorchModel(torch.nn.Module):
         node_to_output_tensors_dict_float = dict()
         configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO)
         for node in self.node_sort:
+            op_func = self._get_op_func(node, configurable_nodes)
             input_tensors = _build_input_tensors_list(node,
                                                       self.graph,
                                                       args,
-                                                      node_to_output_tensors_dict)
-
-            op_func = self._get_op_func(node, configurable_nodes)
+                                                      node_to_output_tensors_dict,
+                                                      isinstance(op_func, PytorchQuantizationWrapper))
             use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node)

             # Run node operation and fetch outputs
@@ -326,15 +338,16 @@ class PytorchModel(torch.nn.Module):
         """
         return getattr(self, node.name)

-    def _get_activation_quantization_fn(self, node) -> Tuple[bool, bool, Callable]:
+    def _get_activation_quantization_fn(self, node) -> Tuple[bool, Callable]:
         """
         Get activation quantization parameters for this node.

         Args:
             node: Node from which to extract the activation quantization parameters.

-        Returns: Flag to indicate if we quantize activations, flag to indicate if we quantize activations
-        using a quantization holder and a quantization function to use for the node's activations.
+        Returns:
+            Flag to indicate if we quantize activations using a quantization holder and a quantization
+            function to use for the node's activations.
         """
         activation_quantization_holder = self.node_to_activation_quantization_holder.get(node.name)
         use_activation_quantization = node.is_activation_quantization_enabled()
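Note: to make the index bookkeeping concrete, the non-wrapper branch of _merge_inputs splices the activation tensors into the remaining call args at their recorded positions, for example:

# Illustrative run of the non-wrapper branch of _merge_inputs:
op_call_args = ['some_arg']        # remaining framework call args
tensor_input_indices = [0, 1]      # recorded positions of the activation inputs
input_tensors = ['t0', 't1']

_input_list = op_call_args.copy()
for i, t in zip(tensor_input_indices, input_tensors):
    _input_list.insert(i, t)
# _input_list == ['t0', 't1', 'some_arg']
# For a PytorchQuantizationWrapper op the tensors are prepended as-is instead,
# because the wrapper re-inserts its own (quantized) positional constants.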
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py
@@ -62,11 +62,11 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode,
     Returns:
         The modified convolution node's weight/kernel/
     """
-    if conv_node.type == ConvTranspose2d:
+    if conv_node.is_match_type(ConvTranspose2d):
         _scale = weights_scale[None, :, None, None]
     else:
         _scale = weights_scale[:, None, None, None]
-    if conv_node.type == ConvTranspose2d and conv_node.framework_attr[GROUPS] > 1:
+    if conv_node.is_match_type(ConvTranspose2d) and conv_node.framework_attr[GROUPS] > 1:
         # PyTorch ConvTranspose2d kernel with groups stacks groups on in_channels axis, so need to reshape the kernel
         # so the groups are stacked on the out_channels axis to match the scale vector (then reshape back to original
         # shape)
@@ -93,10 +93,10 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
     Returns:
         The modified convolution node's weight/kernel/
     """
-    if conv_node.type == Conv2d and conv_node.framework_attr[GROUPS] > 1:
+    if conv_node.is_match_type(Conv2d) and conv_node.framework_attr[GROUPS] > 1:
         bias_update = (kernel * bias_factor[:, None, None, None]).flatten()
         _scale = weights_scale[:, None, None, None]
-    elif conv_node.type == ConvTranspose2d:
+    elif conv_node.is_match_type(ConvTranspose2d):
         bias_update = (kernel * bias_factor[:, None, None, None]).sum(axis=0).flatten()
         _scale = weights_scale[:, None, None, None]
     else:
@@ -125,8 +125,8 @@ def is_group_conv_fn(node: BaseNode) -> bool:
     Returns:
         True if the node is a group convolution, else False
     """
-    return node.type in [Conv2d, ConvTranspose2d] and \
-           node.framework_attr[GROUPS] not in [node.framework_attr[IN_CHANNELS], 1]
+    return (node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d)) and \
+           node.framework_attr[GROUPS] not in [node.framework_attr[IN_CHANNELS], 1]


 def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
@@ -140,8 +140,8 @@ def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
         is_bn: True if the node is a batch norm, else False
         is_dw_valid: True if the node is a dw-convolution valid for folding or a batch-norm node, else False
     """
-    is_bn = node.type is BatchNorm2d
-    is_dw = node.type is Conv2d and node.framework_attr[GROUPS] == node.framework_attr[IN_CHANNELS]
+    is_bn = node.is_match_type(BatchNorm2d)
+    is_dw = node.is_match_type(Conv2d) and node.framework_attr[GROUPS] == node.framework_attr[IN_CHANNELS]
     is_dw_valid = is_dw and np.all(np.array(node.get_weights_by_keys(KERNEL).shape[2:]) == 1)
     return is_bn, is_dw_valid

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py
@@ -48,9 +48,9 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
             Graph after applying the substitution.
         """
         # Set new layer
-        if func_node.type == conv2d:
+        if func_node.is_match_type(conv2d):
             new_layer = Conv2d
-        elif func_node.type == conv_transpose2d:
+        elif func_node.is_match_type(conv_transpose2d):
             new_layer = ConvTranspose2d
         else:
             Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.')  # pragma: no cover
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py
@@ -53,7 +53,7 @@ def conv2d_collapsing_fn(first_node: BaseNode,
     Returns:
         The modified layer node's weights: kernel, bias
     """
-    if first_node.type == Conv2d and second_node.type == Conv2d:
+    if first_node.is_match_type(Conv2d) and second_node.is_match_type(Conv2d):
         # Get nodes attributes
         kernel1 = first_node.get_weights_by_keys(kernel_str)
         kernel2 = second_node.get_weights_by_keys(kernel_str)
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py
@@ -76,17 +76,17 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
         second_op2d_node = nodes_list[2]

         # only act on bound relu with not POT max value and 0 min value
-        if non_linear_node.type == ReLU6:
+        if non_linear_node.is_match_type(ReLU6):
             scale_factor = 6.0 / self.threshold
             non_linear_node.layer_class = Hardtanh
             non_linear_node.framework_attr[INPLACE] = False
             non_linear_node.framework_attr[HARDTANH_MIN_VAL] = 0.0
             non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold
-        elif non_linear_node.type == relu6:
+        elif non_linear_node.is_match_type(relu6):
             scale_factor = 6.0 / self.threshold
             non_linear_node.functional_op = hardtanh
             non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, False)
-        elif non_linear_node.type == Hardtanh:
+        elif non_linear_node.is_match_type(Hardtanh):
             if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
                     (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
                      np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
@@ -94,7 +94,7 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
                 non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold
             else:
                 return graph
-        elif non_linear_node.type == hardtanh:
+        elif non_linear_node.is_match_type(hardtanh):
             if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
                     (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
                      np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):