mct-nightly 2.0.0.20240418.439__py3-none-any.whl → 2.0.0.20240420.357__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240420.357.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240420.357.dist-info}/RECORD +39 -39
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  5. model_compression_toolkit/core/common/graph/base_node.py +25 -8
  6. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  7. model_compression_toolkit/core/common/network_editors/node_filters.py +4 -3
  8. model_compression_toolkit/core/common/quantization/node_quantization_config.py +0 -5
  9. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -3
  10. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  11. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +4 -1
  12. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -7
  13. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  14. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  15. model_compression_toolkit/core/keras/keras_implementation.py +10 -10
  16. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -4
  17. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +4 -5
  18. model_compression_toolkit/core/keras/reader/common.py +2 -2
  19. model_compression_toolkit/core/keras/reader/node_builder.py +28 -9
  20. model_compression_toolkit/core/keras/tf_tensor_numpy.py +5 -2
  21. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +34 -21
  22. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py +8 -8
  23. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
  24. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  25. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +4 -4
  26. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  27. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +8 -8
  28. model_compression_toolkit/core/pytorch/pytorch_implementation.py +4 -5
  29. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +2 -2
  30. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -1
  31. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +9 -2
  32. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +1 -1
  33. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +20 -6
  34. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +1 -1
  35. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +22 -8
  36. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +1 -1
  37. {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240420.357.dist-info}/LICENSE.md +0 -0
  38. {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240420.357.dist-info}/WHEEL +0 -0
  39. {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240420.357.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from functools import partial
16
- from typing import List, Any, Tuple, Callable, Dict
16
+ from typing import List, Any, Tuple, Callable, Dict, Union
17
17
 
18
18
  import numpy as np
19
19
  import tensorflow as tf
@@ -412,12 +412,13 @@ class KerasImplementation(FrameworkImplementation):
412
412
  Returns: True if the node should be considered an interest point, False otherwise.
413
413
  """
414
414
 
415
- if node.type == Activation:
415
+ if node.is_match_type(Activation):
416
416
  node_type_name = node.framework_attr[keras_constants.ACTIVATION]
417
417
  if node_type_name in [keras_constants.SOFTMAX, keras_constants.SIGMOID]:
418
418
  return True
419
- elif node.type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate,
420
- tf.concat, Add, tf.add]:
419
+ elif any([node.is_match_type(_type) for _type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D,
420
+ DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate, tf.concat,
421
+ Add, tf.add]]):
421
422
  return True
422
423
 
423
424
  return False
@@ -529,18 +530,18 @@ class KerasImplementation(FrameworkImplementation):
529
530
  kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
530
531
  output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
531
532
 
532
- if node.type is Conv2D or node.type is Conv2DTranspose:
533
+ if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose):
533
534
  # (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
534
535
  return np.prod([x for x in output_shape if x is not None]) * \
535
536
  kernel_shape[input_channel_axis] * \
536
537
  (kernel_shape[0] * kernel_shape[1])
537
- elif node.type is DepthwiseConv2D:
538
+ elif node.is_match_type(DepthwiseConv2D):
538
539
  # Depth * (W_out * H_out) * C_in * (W_kernel * H_kernel)
539
540
  return node.framework_attr.get(DEPTH_MULTIPLIER) * \
540
541
  np.prod([x for x in output_shape if x is not None]) / output_shape[output_channel_axis] * \
541
542
  kernel_shape[input_channel_axis] * \
542
543
  (kernel_shape[0] * kernel_shape[1])
543
- elif node.type is Dense:
544
+ elif node.is_match_type(Dense):
544
545
  # IN * OUT
545
546
  return kernel_shape[0] * kernel_shape[1]
546
547
  else:
@@ -593,10 +594,9 @@ class KerasImplementation(FrameworkImplementation):
593
594
  Returns:
594
595
  weight_quantizers: A dictionary between a weight's name to its quantizer.
595
596
  activation_quantizers: A list of activations quantization, one for each layer output.
596
-
597
597
  """
598
598
 
599
- def _weight_name(w: str) -> str:
599
+ def _weight_name(w: Union[str, int]) -> Union[str, int]:
600
600
  """
601
601
  Extracts the weight name from the full TensorFlow variable name.
602
602
 
@@ -609,7 +609,7 @@ class KerasImplementation(FrameworkImplementation):
609
609
  Extracted weight name.
610
610
  """
611
611
 
612
- return w.split(':')[0].split('/')[-1]
612
+ return w.split(':')[0].split('/')[-1] if isinstance(w, str) else w
613
613
 
614
614
  attribute_names = [_weight_name(wn) for wn in node.get_node_weights_attributes()
615
615
  if node.is_weights_quantization_enabled(wn)]
@@ -56,13 +56,13 @@ def _get_min_max_outputs(node: BaseNode,
56
56
  """
57
57
  min_output, max_output = None, None
58
58
 
59
- if node.type == ReLU:
59
+ if node.is_match_type(ReLU):
60
60
  min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None
61
61
 
62
62
  elif fw_info.layers_has_min_max(node.type):
63
63
  min_output, max_output = fw_info.layer_min_max_mapping[node.type]
64
64
 
65
- elif node.type == Activation and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
65
+ elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
66
66
  min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]
67
67
 
68
68
  return min_output, max_output
@@ -82,7 +82,7 @@ def _get_mean_std_outputs(node: BaseNode,
82
82
  """
83
83
  mean_output, std_output = None, None
84
84
 
85
- if node.type == BatchNormalization:
85
+ if node.is_match_type(BatchNormalization):
86
86
  mean_output = node.get_weights_by_keys(BETA)
87
87
  if node.get_weights_by_keys(GAMMA) is None:
88
88
  std_output = 1.0
@@ -92,7 +92,7 @@ def _get_mean_std_outputs(node: BaseNode,
92
92
  mean_output = 0.0
93
93
  else:
94
94
  next_node_list = graph.get_next_nodes(node)
95
- bn_nodes = [bn_node for bn_node in next_node_list if bn_node.type == BatchNormalization]
95
+ bn_nodes = [bn_node for bn_node in next_node_list if bn_node.is_match_type(BatchNormalization)]
96
96
  if len(bn_nodes) != 0:
97
97
  bn_node = bn_nodes[0]
98
98
  moving_variance = bn_node.get_weights_by_keys(MOVING_VARIANCE)
@@ -209,10 +209,9 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
209
209
  """
210
210
 
211
211
  # Check if the node is a Conv2D or Conv2DTranspose layer with groups set to 1.
212
- if node.type in [keras.layers.Conv2D, keras.layers.Conv2DTranspose]:
212
+ if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose):
213
213
  return node.framework_attr[GROUPS] == 1
214
- return node.type == keras.layers.Dense
215
-
214
+ return node.is_match_type(keras.layers.Dense)
216
215
 
217
216
 
218
217
  def _prune_keras_edge_node(node: BaseNode,
@@ -250,9 +249,9 @@ def _prune_keras_edge_node(node: BaseNode,
250
249
 
251
250
  if not is_exit_node:
252
251
  # Update 'filters' or 'units' attributes for entry node Conv2D/Conv2DTranspose layers.
253
- if node.type in [keras.layers.Conv2D, keras.layers.Conv2DTranspose]:
252
+ if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose):
254
253
  node.framework_attr[FILTERS] = int(np.sum(mask))
255
- elif node.type == keras.layers.Dense:
254
+ elif node.is_match_type(keras.layers.Dense):
256
255
  node.framework_attr[UNITS] = int(np.sum(mask))
257
256
 
258
257
  if is_exit_node:
@@ -43,7 +43,7 @@ def is_node_an_input_layer(node: BaseNode) -> bool:
43
43
  Whether the node represents an input layer or not.
44
44
  """
45
45
  if isinstance(node, BaseNode):
46
- return node.type == InputLayer
46
+ return node.is_match_type(InputLayer)
47
47
  elif isinstance(node, KerasNode):
48
48
  return isinstance(node.layer, InputLayer)
49
49
  else:
@@ -60,7 +60,7 @@ def is_node_a_model(node: BaseNode) -> bool:
60
60
  Whether the node represents a Keras model or not.
61
61
  """
62
62
  if isinstance(node, BaseNode):
63
- return node.type in [Functional, Sequential]
63
+ return node.is_match_type(Functional) or node.is_match_type(Sequential)
64
64
  elif isinstance(node, KerasNode):
65
65
  return isinstance(node.layer, Functional) or isinstance(node.layer, Sequential)
66
66
  else:
@@ -41,7 +41,7 @@ layers = keras.layers
41
41
 
42
42
  REUSED_IDENTIFIER = '_reused_'
43
43
 
44
- is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray))
44
+ is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, float))
45
45
  is_tensor = lambda x: isinstance(x, KerasTensor)
46
46
 
47
47
 
@@ -61,18 +61,36 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
61
61
  """
62
62
  Positional weights are saved according to their index in the node's call arguments, so
63
63
  need to know the function arguments' names in case the weights are in the kwargs.
64
+
65
+ Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so
66
+ it will only include the arguments that may contain constants. For example, we don't
67
+ want the transpose_a attribute of tf.matmul to be saved as a constant.
68
+
69
+ Every operation we add support to, needs to be added here.
70
+
64
71
  Args:
65
72
  tfoplambda_layer: TFOpLambda layer.
66
73
 
67
74
  Returns:
68
75
  A dictionary with argument number and index: {arg_name: arg_index}.
69
76
  """
70
- if tfoplambda_layer.function in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
71
- tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression] or \
72
- tfoplambda_layer.symbol in ['__operators__.add', 'math.add', 'math.multiply', 'linalg.matmul', 'concat']:
73
- return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tfoplambda_layer.function).args)}
74
- else:
75
- return {}
77
+ kwargs2index = {tf.add: {'x': 0, 'y': 1},
78
+ tf.subtract: {'x': 0, 'y': 1},
79
+ tf.divide: {'x': 0, 'y': 1},
80
+ tf.truediv: {'x': 0, 'y': 1},
81
+ tf.multiply: {'x': 0, 'y': 1},
82
+ tf.pow: {'x': 0, 'y': 1},
83
+ tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function)
84
+ if not kwargs2index:
85
+ # In TF 2.15 the function attribute is different and doesn't match the original
86
+ # operation object we use. Therefore, we extract kwargs2index with the symbol.
87
+ kwargs2index = {'__operators__.add': {'x': 0, 'y': 1},
88
+ 'math.add': {'x': 0, 'y': 1},
89
+ 'math.multiply': {'x': 0, 'y': 1},
90
+ 'linalg.matmul': {'a': 0, 'b': 1},
91
+ 'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {})
92
+
93
+ return kwargs2index
76
94
 
77
95
 
78
96
  def build_node(node: KerasNode,
@@ -154,8 +172,9 @@ def build_node(node: KerasNode,
154
172
  if is_const(v) or (keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow,
155
173
  tf.matmul] and
156
174
  isinstance(v, (tuple, list))):
157
- weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
158
- weight_keys.append(k)
175
+ if k in kwarg2index:
176
+ weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
177
+ weight_keys.append(k)
159
178
  # remove weights and KerasTensors and weights from op_call_kwargs
160
179
  op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
161
180
  if not (kwarg2index.get(k) in weights or is_tensor(v))}
@@ -40,7 +40,7 @@ def to_tf_tensor(tensor):
40
40
  Logger.critical(f'Unsupported type for conversion to TF tensor: {type(tensor)}.')
41
41
 
42
42
 
43
- def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor],
43
+ def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor, float],
44
44
  is_single_tensor=False) -> np.ndarray:
45
45
  """
46
46
  Convert a TF tensor to a Numpy array.
@@ -65,6 +65,9 @@ def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor],
65
65
  else:
66
66
  return (tf_tensor_to_numpy(t) for t in tensor)
67
67
  elif isinstance(tensor, tf.Tensor):
68
- return tensor.numpy()
68
+ np_tensor = tensor.numpy()
69
+ return np.array([np_tensor]) if np.isscalar(np_tensor) else np_tensor
70
+ elif isinstance(tensor, float):
71
+ return np.array([tensor])
69
72
  else:
70
73
  Logger.critical(f'Unsupported type for conversion to Numpy array: {type(tensor)}.')
@@ -33,26 +33,31 @@ from model_compression_toolkit.core.pytorch.pytorch_device_config import get_wor
33
33
  from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
34
34
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
35
35
  from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
36
+ from mct_quantizers import PytorchQuantizationWrapper
36
37
 
37
38
 
38
39
  def _build_input_tensors_list(node: BaseNode,
39
40
  graph: Graph,
40
41
  inputs: Tuple[Any],
41
- node_to_output_tensors_dict: Dict[BaseNode, List]) -> List[List]:
42
+ node_to_output_tensors_dict: Dict[BaseNode, List],
43
+ is_op_quantize_wrapper: bool) -> List[List]:
42
44
  """
43
- Given a node, build a list of input tensors the node gets. The list is built
44
- based on the node's incoming edges and previous nodes' output tensors.
45
+ Given a node, build a list of input tensors the node gets. The list is built based on the
46
+ node's incoming edges, previous nodes' output tensors and the node's positional weights.
47
+ Positional weights aren't used if the node's op is PytorchQuantizationWrapper, since it's
48
+ positional weights are already in the wrapper.
45
49
 
46
50
  Args:
47
51
  node: Node to build its input tensors list.
48
52
  graph: Graph the node is in.
49
- inputs: list of input tensors to model
53
+ inputs: list of input tensors to model.
50
54
  node_to_output_tensors_dict: A dictionary from a node to its output tensors.
55
+ is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.
51
56
 
52
57
  Returns:
53
58
  A list of the node's input tensors.
54
59
  """
55
- if node.type == DummyPlaceHolder:
60
+ if node.is_match_type(DummyPlaceHolder):
56
61
  input_tensors = [inputs[graph.get_inputs().index(node)]]
57
62
  else:
58
63
  input_tensors = []
@@ -62,7 +67,8 @@ def _build_input_tensors_list(node: BaseNode,
62
67
  _input_tensors = node_to_output_tensors_dict[ie.source_node]
63
68
  input_tensors.append(_input_tensors)
64
69
  input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
65
- input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
70
+ if not is_op_quantize_wrapper:
71
+ input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
66
72
  # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
67
73
  # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
68
74
  input_tensors = [to_torch_tensor(t) if isinstance(t, np.ndarray) else t
@@ -70,22 +76,27 @@ def _build_input_tensors_list(node: BaseNode,
70
76
  return input_tensors
71
77
 
72
78
 
73
- def _merge_inputs(_node, input_tensors: List, op_call_args: List) -> List:
79
+ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
80
+ is_op_quantize_wrapper: bool) -> List:
74
81
  """
75
- Merge input tensors list with op_call_args, according to correct order
82
+ Merge input tensors list with op_call_args, according to correct order.
76
83
 
77
84
  Args:
78
- _node: The node the inputs are for
85
+ _node: The node the inputs are for.
79
86
  input_tensors: activation input tensors to node.
80
- op_call_args: framework node call args
87
+ op_call_args: framework node call args.
88
+ is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.
81
89
  Returns:
82
- Combined list of input_tensors and op_call_args
90
+ Combined list of input_tensors and op_call_args.
83
91
  """
84
92
  if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
85
- assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices'
86
93
  _input_list = op_call_args.copy()
87
- for i, t in zip(_node.tensor_input_indices, input_tensors):
88
- _input_list.insert(i, t)
94
+ if is_op_quantize_wrapper:
95
+ _input_list = input_tensors + _input_list
96
+ else:
97
+ assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices'
98
+ for i, t in zip(_node.tensor_input_indices, input_tensors):
99
+ _input_list.insert(i, t)
89
100
  else:
90
101
  _input_list = input_tensors + op_call_args
91
102
 
@@ -118,7 +129,8 @@ def _run_operation(n: BaseNode,
118
129
  if isinstance(n, FunctionalNode) and n.inputs_as_list:
119
130
  out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
120
131
  else:
121
- out_tensors_of_n_float = op_func(*_merge_inputs(n, input_tensors, op_call_args), **functional_kwargs)
132
+ merged_inputs = _merge_inputs(n, input_tensors, op_call_args, isinstance(op_func, PytorchQuantizationWrapper))
133
+ out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
122
134
 
123
135
  # Add a fake quant node if the node has an activation threshold.
124
136
  out_tensors_of_n = out_tensors_of_n_float
@@ -279,12 +291,12 @@ class PytorchModel(torch.nn.Module):
279
291
  node_to_output_tensors_dict_float = dict()
280
292
  configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO)
281
293
  for node in self.node_sort:
294
+ op_func = self._get_op_func(node, configurable_nodes)
282
295
  input_tensors = _build_input_tensors_list(node,
283
296
  self.graph,
284
297
  args,
285
- node_to_output_tensors_dict)
286
-
287
- op_func = self._get_op_func(node, configurable_nodes)
298
+ node_to_output_tensors_dict,
299
+ isinstance(op_func, PytorchQuantizationWrapper))
288
300
  use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node)
289
301
 
290
302
  # Run node operation and fetch outputs
@@ -326,15 +338,16 @@ class PytorchModel(torch.nn.Module):
326
338
  """
327
339
  return getattr(self, node.name)
328
340
 
329
- def _get_activation_quantization_fn(self, node) -> Tuple[bool, bool, Callable]:
341
+ def _get_activation_quantization_fn(self, node) -> Tuple[bool, Callable]:
330
342
  """
331
343
  Get activation quantization parameters for this node.
332
344
 
333
345
  Args:
334
346
  node: Node from which to extract the activation quantization parameters.
335
347
 
336
- Returns: Flag to indicate if we quantize activations, flag to indicate if we quantize activations
337
- using a quantization holder and a quantization function to use for the node's activations.
348
+ Returns:
349
+ Flag to indicate if we quantize activations using a quantization holder and a quantization
350
+ function to use for the node's activations.
338
351
  """
339
352
  activation_quantization_holder = self.node_to_activation_quantization_holder.get(node.name)
340
353
  use_activation_quantization = node.is_activation_quantization_enabled()
@@ -62,11 +62,11 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode,
62
62
  Returns:
63
63
  The modified convolution node's weight/kernel/
64
64
  """
65
- if conv_node.type == ConvTranspose2d:
65
+ if conv_node.is_match_type(ConvTranspose2d):
66
66
  _scale = weights_scale[None, :, None, None]
67
67
  else:
68
68
  _scale = weights_scale[:, None, None, None]
69
- if conv_node.type == ConvTranspose2d and conv_node.framework_attr[GROUPS] > 1:
69
+ if conv_node.is_match_type(ConvTranspose2d) and conv_node.framework_attr[GROUPS] > 1:
70
70
  # PyTorch ConvTranspose2d kernel with groups stacks groups on in_channels axis, so need to reshape the kernel
71
71
  # so the groups are stacked on the out_channels axis to match the scale vector (then reshape back to original
72
72
  # shape)
@@ -93,10 +93,10 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
93
93
  Returns:
94
94
  The modified convolution node's weight/kernel/
95
95
  """
96
- if conv_node.type == Conv2d and conv_node.framework_attr[GROUPS] > 1:
96
+ if conv_node.is_match_type(Conv2d) and conv_node.framework_attr[GROUPS] > 1:
97
97
  bias_update = (kernel * bias_factor[:, None, None, None]).flatten()
98
98
  _scale = weights_scale[:, None, None, None]
99
- elif conv_node.type == ConvTranspose2d:
99
+ elif conv_node.is_match_type(ConvTranspose2d):
100
100
  bias_update = (kernel * bias_factor[:, None, None, None]).sum(axis=0).flatten()
101
101
  _scale = weights_scale[:, None, None, None]
102
102
  else:
@@ -125,8 +125,8 @@ def is_group_conv_fn(node: BaseNode) -> bool:
125
125
  Returns:
126
126
  True if the node is a group convolution, else False
127
127
  """
128
- return node.type in [Conv2d, ConvTranspose2d] and \
129
- node.framework_attr[GROUPS] not in [node.framework_attr[IN_CHANNELS], 1]
128
+ return (node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d)) and \
129
+ node.framework_attr[GROUPS] not in [node.framework_attr[IN_CHANNELS], 1]
130
130
 
131
131
 
132
132
  def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
@@ -140,8 +140,8 @@ def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
140
140
  is_bn: True if the node is a batch norm, else False
141
141
  is_dw_valid: True if the node is a dw-convolution valid for folding or a batch-norm node, else False
142
142
  """
143
- is_bn = node.type is BatchNorm2d
144
- is_dw = node.type is Conv2d and node.framework_attr[GROUPS] == node.framework_attr[IN_CHANNELS]
143
+ is_bn = node.is_match_type(BatchNorm2d)
144
+ is_dw = node.is_match_type(Conv2d) and node.framework_attr[GROUPS] == node.framework_attr[IN_CHANNELS]
145
145
  is_dw_valid = is_dw and np.all(np.array(node.get_weights_by_keys(KERNEL).shape[2:]) == 1)
146
146
  return is_bn, is_dw_valid
147
147
 
@@ -48,9 +48,9 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
48
48
  Graph after applying the substitution.
49
49
  """
50
50
  # Set new layer
51
- if func_node.type == conv2d:
51
+ if func_node.is_match_type(conv2d):
52
52
  new_layer = Conv2d
53
- elif func_node.type == conv_transpose2d:
53
+ elif func_node.is_match_type(conv_transpose2d):
54
54
  new_layer = ConvTranspose2d
55
55
  else:
56
56
  Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.') # pragma: no cover
@@ -53,7 +53,7 @@ def conv2d_collapsing_fn(first_node: BaseNode,
53
53
  Returns:
54
54
  The modified layer node's weights: kernel, bias
55
55
  """
56
- if first_node.type == Conv2d and second_node.type == Conv2d:
56
+ if first_node.is_match_type(Conv2d) and second_node.is_match_type(Conv2d):
57
57
  # Get nodes attributes
58
58
  kernel1 = first_node.get_weights_by_keys(kernel_str)
59
59
  kernel2 = second_node.get_weights_by_keys(kernel_str)
@@ -76,17 +76,17 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
76
76
  second_op2d_node = nodes_list[2]
77
77
 
78
78
  # only act on bound relu with not POT max value and 0 min value
79
- if non_linear_node.type == ReLU6:
79
+ if non_linear_node.is_match_type(ReLU6):
80
80
  scale_factor = 6.0 / self.threshold
81
81
  non_linear_node.layer_class = Hardtanh
82
82
  non_linear_node.framework_attr[INPLACE] = False
83
83
  non_linear_node.framework_attr[HARDTANH_MIN_VAL] = 0.0
84
84
  non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold
85
- elif non_linear_node.type == relu6:
85
+ elif non_linear_node.is_match_type(relu6):
86
86
  scale_factor = 6.0 / self.threshold
87
87
  non_linear_node.functional_op = hardtanh
88
88
  non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, False)
89
- elif non_linear_node.type == Hardtanh:
89
+ elif non_linear_node.is_match_type(Hardtanh):
90
90
  if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
91
91
  (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
92
92
  np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
@@ -94,7 +94,7 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
94
94
  non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold
95
95
  else:
96
96
  return graph
97
- elif non_linear_node.type == hardtanh:
97
+ elif non_linear_node.is_match_type(hardtanh):
98
98
  if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
99
99
  (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
100
100
  np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
@@ -46,7 +46,7 @@ def residual_collapsing_fn(first_node: BaseNode,
46
46
  Returns:
47
47
  The modified layer node's weights: kernel
48
48
  """
49
- if first_node.type == Conv2d:
49
+ if first_node.is_match_type(Conv2d):
50
50
  # Get nodes attributes
51
51
  kernel = first_node.get_weights_by_keys(kernel_str)
52
52
  (Cout, Cin, kH, kW) = kernel.shape
@@ -76,9 +76,9 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
76
76
  pruned_parameters = {}
77
77
  mask_bool = output_mask.astype(bool)
78
78
  node.weights = pruned_parameters
79
- if node.type == torch.nn.BatchNorm2d:
79
+ if node.is_match_type(torch.nn.BatchNorm2d):
80
80
  node.framework_attr[NUM_FEATURES] = int(np.sum(input_mask))
81
- elif node.type == torch.nn.PReLU:
81
+ elif node.is_match_type(torch.nn.PReLU):
82
82
  if node.framework_attr[NUM_PARAMETERS] > 1:
83
83
  node.framework_attr[NUM_PARAMETERS] = int(np.sum(input_mask))
84
84
  else:
@@ -227,9 +227,9 @@ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
227
227
  """
228
228
 
229
229
  # Check if the node is a Conv2D or Conv2DTranspose layer with groups set to 1.
230
- if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
230
+ if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d):
231
231
  return node.framework_attr[GROUPS] == 1
232
- return node.type == torch.nn.Linear
232
+ return node.is_match_type(torch.nn.Linear)
233
233
 
234
234
 
235
235
  def _prune_pytorch_edge_node(node: BaseNode,
@@ -268,18 +268,18 @@ def _prune_pytorch_edge_node(node: BaseNode,
268
268
  if not is_exit_node:
269
269
  # Update 'out_channels' or 'out_features' attributes for entry nodes
270
270
  # Conv2d,ConvTranspose2d / Linear layers.
271
- if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
271
+ if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d):
272
272
  node.framework_attr[OUT_CHANNELS] = int(np.sum(mask))
273
- elif node.type == torch.nn.Linear:
273
+ elif node.is_match_type(torch.nn.Linear):
274
274
  node.framework_attr[OUT_FEATURES] = int(np.sum(mask))
275
275
  else:
276
276
  Logger.critical(f"{node.type} is currently not supported"
277
277
  f"as an edge node in a pruning section")
278
278
 
279
279
  if is_exit_node:
280
- if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
280
+ if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d):
281
281
  node.framework_attr[IN_CHANNELS] = int(np.sum(mask))
282
- elif node.type == torch.nn.Linear:
282
+ elif node.is_match_type(torch.nn.Linear):
283
283
  node.framework_attr[IN_FEATURES] = int(np.sum(mask))
284
284
  else:
285
285
  Logger.critical(f"{node.type} is currently not supported"
@@ -398,8 +398,8 @@ class PytorchImplementation(FrameworkImplementation):
398
398
  Returns: True if the node should be considered an interest point, False otherwise.
399
399
  """
400
400
 
401
- if node.type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax, softmax, operator.add, add, cat,
402
- operator.concat]:
401
+ if any([node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax,
402
+ softmax, operator.add, add, cat, operator.concat]]):
403
403
  return True
404
404
  return False
405
405
 
@@ -464,12 +464,12 @@ class PytorchImplementation(FrameworkImplementation):
464
464
  kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
465
465
  output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
466
466
 
467
- if node.type is Conv2d or node.type is ConvTranspose2d:
467
+ if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
468
468
  # (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
469
469
  return np.prod([x for x in output_shape if x is not None]) * \
470
470
  kernel_shape[input_channel_axis] * \
471
471
  (kernel_shape[0] * kernel_shape[1])
472
- elif node.type is Linear:
472
+ elif node.is_match_type(Linear):
473
473
  # IN * OUT
474
474
  return kernel_shape[0] * kernel_shape[1]
475
475
  else:
@@ -552,7 +552,6 @@ class PytorchImplementation(FrameworkImplementation):
552
552
  Returns:
553
553
  weight_quantizers: A dictionary between a weight's name to its quantizer.
554
554
  activation_quantizers: A list of activations quantization, one for each layer output.
555
-
556
555
  """
557
556
 
558
557
  return get_inferable_quantizers(node,
@@ -62,7 +62,7 @@ def _get_mean_std_outputs(node: BaseNode,
62
62
  """
63
63
  mean_output, std_output = None, None
64
64
 
65
- if node.type == BatchNorm2d:
65
+ if node.is_match_type(BatchNorm2d):
66
66
  mean_output = node.get_weights_by_keys(BETA)
67
67
  if node.get_weights_by_keys(GAMMA) is None:
68
68
  std_output = 1.0
@@ -72,7 +72,7 @@ def _get_mean_std_outputs(node: BaseNode,
72
72
  mean_output = 0.0
73
73
  else:
74
74
  next_node_list = graph.get_next_nodes(node)
75
- bn_nodes = [bn_node for bn_node in next_node_list if bn_node.type == BatchNorm2d]
75
+ bn_nodes = [bn_node for bn_node in next_node_list if bn_node.is_match_type(BatchNorm2d)]
76
76
  if len(bn_nodes) != 0:
77
77
  bn_node = bn_nodes[0]
78
78
  moving_variance = bn_node.get_weights_by_keys(MOVING_VARIANCE)
@@ -42,8 +42,12 @@ if FOUND_TF:
42
42
  """
43
43
  weights_quantizers, _ = fw_impl.get_inferable_quantizers(node)
44
44
  if len(weights_quantizers) > 0:
45
+ # for positional weights we need to extract the weight's value.
46
+ weights_values = {attr: node.get_weights_by_keys(attr)
47
+ for attr in weights_quantizers if isinstance(attr, int)}
45
48
  return KerasQuantizationWrapper(layer,
46
- weights_quantizers)
49
+ weights_quantizers,
50
+ weights_values)
47
51
  return layer
48
52
 
49
53
 
@@ -29,7 +29,7 @@ if FOUND_TORCH:
29
29
 
30
30
  def fully_quantized_wrapper(node: common.BaseNode,
31
31
  module: torch.nn.Module,
32
- fw_impl) -> Union[torch.nn.Module,PytorchQuantizationWrapper]:
32
+ fw_impl) -> Union[torch.nn.Module, PytorchQuantizationWrapper]:
33
33
  """
34
34
  A function which takes a computational graph node and a pytorch module and
35
35
  perform the quantization wrapping
@@ -37,20 +37,26 @@ if FOUND_TORCH:
37
37
  Args:
38
38
  node: A node of mct graph.
39
39
  module: A Pytorch module
40
+ fw_impl: FrameworkImplementation object with a specific framework methods implementation.
40
41
  Returns: Wrapped layer
41
42
 
42
43
  """
43
44
  weight_quantizers, _ = fw_impl.get_inferable_quantizers(node)
44
45
  if len(weight_quantizers) > 0:
45
- return PytorchQuantizationWrapper(module, weight_quantizers)
46
+ # for positional weights we need to extract the weight's value.
47
+ weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
48
+ for attr in weight_quantizers if isinstance(attr, int)}
49
+ return PytorchQuantizationWrapper(module, weight_quantizers, weights_values)
46
50
  return module
47
51
 
52
+
48
53
  def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable:
49
54
  """
50
55
  Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
51
56
  If the layer is not supposed to be wrapped with an activation quantizer - return None.
52
57
  Args:
53
58
  node: Node to attach a PytorchActivationQuantizationHolder to its output.
59
+ fw_impl: FrameworkImplementation object with a specific framework methods implementation.
54
60
  Returns:
55
61
  A PytorchActivationQuantizationHolder module for the node's activation quantization.
56
62
  """
@@ -64,6 +70,7 @@ if FOUND_TORCH:
64
70
  f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
65
71
  f'were found for node {node}')
66
72
 
73
+
67
74
  def get_exportable_pytorch_model(graph: Graph):
68
75
  """
69
76
  Convert graph to fully quantized PyTorch model.