mct-nightly 2.0.0.20240418.439__py3-none-any.whl → 2.0.0.20240419.358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/METADATA +1 -1
- {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/RECORD +39 -39
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +25 -8
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/network_editors/node_filters.py +4 -3
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +0 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -3
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +4 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +10 -10
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -4
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +4 -5
- model_compression_toolkit/core/keras/reader/common.py +2 -2
- model_compression_toolkit/core/keras/reader/node_builder.py +28 -9
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +5 -2
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +34 -21
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py +8 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +4 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +8 -8
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +4 -5
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +9 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +20 -6
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +22 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +1 -1
- {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/WHEEL +0 -0
- {mct_nightly-2.0.0.20240418.439.dist-info → mct_nightly-2.0.0.20240419.358.dist-info}/top_level.txt +0 -0
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from functools import partial
|
|
16
|
-
from typing import List, Any, Tuple, Callable, Dict
|
|
16
|
+
from typing import List, Any, Tuple, Callable, Dict, Union
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import tensorflow as tf
|
|
@@ -412,12 +412,13 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
412
412
|
Returns: True if the node should be considered an interest point, False otherwise.
|
|
413
413
|
"""
|
|
414
414
|
|
|
415
|
-
if node.
|
|
415
|
+
if node.is_match_type(Activation):
|
|
416
416
|
node_type_name = node.framework_attr[keras_constants.ACTIVATION]
|
|
417
417
|
if node_type_name in [keras_constants.SOFTMAX, keras_constants.SIGMOID]:
|
|
418
418
|
return True
|
|
419
|
-
elif node.
|
|
420
|
-
|
|
419
|
+
elif any([node.is_match_type(_type) for _type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D,
|
|
420
|
+
DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate, tf.concat,
|
|
421
|
+
Add, tf.add]]):
|
|
421
422
|
return True
|
|
422
423
|
|
|
423
424
|
return False
|
|
@@ -529,18 +530,18 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
529
530
|
kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
|
|
530
531
|
output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
|
|
531
532
|
|
|
532
|
-
if node.
|
|
533
|
+
if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose):
|
|
533
534
|
# (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
|
|
534
535
|
return np.prod([x for x in output_shape if x is not None]) * \
|
|
535
536
|
kernel_shape[input_channel_axis] * \
|
|
536
537
|
(kernel_shape[0] * kernel_shape[1])
|
|
537
|
-
elif node.
|
|
538
|
+
elif node.is_match_type(DepthwiseConv2D):
|
|
538
539
|
# Depth * (W_out * H_out) * C_in * (W_kernel * H_kernel)
|
|
539
540
|
return node.framework_attr.get(DEPTH_MULTIPLIER) * \
|
|
540
541
|
np.prod([x for x in output_shape if x is not None]) / output_shape[output_channel_axis] * \
|
|
541
542
|
kernel_shape[input_channel_axis] * \
|
|
542
543
|
(kernel_shape[0] * kernel_shape[1])
|
|
543
|
-
elif node.
|
|
544
|
+
elif node.is_match_type(Dense):
|
|
544
545
|
# IN * OUT
|
|
545
546
|
return kernel_shape[0] * kernel_shape[1]
|
|
546
547
|
else:
|
|
@@ -593,10 +594,9 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
593
594
|
Returns:
|
|
594
595
|
weight_quantizers: A dictionary mapping a weight's name to its quantizer.
|
|
595
596
|
activation_quantizers: A list of activation quantizers, one for each layer output.
|
|
596
|
-
|
|
597
597
|
"""
|
|
598
598
|
|
|
599
|
-
def _weight_name(w: str) -> str:
|
|
599
|
+
def _weight_name(w: Union[str, int]) -> Union[str, int]:
|
|
600
600
|
"""
|
|
601
601
|
Extracts the weight name from the full TensorFlow variable name.
|
|
602
602
|
|
|
@@ -609,7 +609,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
609
609
|
Extracted weight name.
|
|
610
610
|
"""
|
|
611
611
|
|
|
612
|
-
return w.split(':')[0].split('/')[-1]
|
|
612
|
+
return w.split(':')[0].split('/')[-1] if isinstance(w, str) else w
|
|
613
613
|
|
|
614
614
|
attribute_names = [_weight_name(wn) for wn in node.get_node_weights_attributes()
|
|
615
615
|
if node.is_weights_quantization_enabled(wn)]
|
|
@@ -56,13 +56,13 @@ def _get_min_max_outputs(node: BaseNode,
|
|
|
56
56
|
"""
|
|
57
57
|
min_output, max_output = None, None
|
|
58
58
|
|
|
59
|
-
if node.
|
|
59
|
+
if node.is_match_type(ReLU):
|
|
60
60
|
min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None
|
|
61
61
|
|
|
62
62
|
elif fw_info.layers_has_min_max(node.type):
|
|
63
63
|
min_output, max_output = fw_info.layer_min_max_mapping[node.type]
|
|
64
64
|
|
|
65
|
-
elif node.
|
|
65
|
+
elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
|
|
66
66
|
min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]
|
|
67
67
|
|
|
68
68
|
return min_output, max_output
|
|
@@ -82,7 +82,7 @@ def _get_mean_std_outputs(node: BaseNode,
|
|
|
82
82
|
"""
|
|
83
83
|
mean_output, std_output = None, None
|
|
84
84
|
|
|
85
|
-
if node.
|
|
85
|
+
if node.is_match_type(BatchNormalization):
|
|
86
86
|
mean_output = node.get_weights_by_keys(BETA)
|
|
87
87
|
if node.get_weights_by_keys(GAMMA) is None:
|
|
88
88
|
std_output = 1.0
|
|
@@ -92,7 +92,7 @@ def _get_mean_std_outputs(node: BaseNode,
|
|
|
92
92
|
mean_output = 0.0
|
|
93
93
|
else:
|
|
94
94
|
next_node_list = graph.get_next_nodes(node)
|
|
95
|
-
bn_nodes = [bn_node for bn_node in next_node_list if bn_node.
|
|
95
|
+
bn_nodes = [bn_node for bn_node in next_node_list if bn_node.is_match_type(BatchNormalization)]
|
|
96
96
|
if len(bn_nodes) != 0:
|
|
97
97
|
bn_node = bn_nodes[0]
|
|
98
98
|
moving_variance = bn_node.get_weights_by_keys(MOVING_VARIANCE)
|
|
@@ -209,10 +209,9 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
|
|
|
209
209
|
"""
|
|
210
210
|
|
|
211
211
|
# Check if the node is a Conv2D or Conv2DTranspose layer with groups set to 1.
|
|
212
|
-
if node.
|
|
212
|
+
if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose):
|
|
213
213
|
return node.framework_attr[GROUPS] == 1
|
|
214
|
-
return node.
|
|
215
|
-
|
|
214
|
+
return node.is_match_type(keras.layers.Dense)
|
|
216
215
|
|
|
217
216
|
|
|
218
217
|
def _prune_keras_edge_node(node: BaseNode,
|
|
@@ -250,9 +249,9 @@ def _prune_keras_edge_node(node: BaseNode,
|
|
|
250
249
|
|
|
251
250
|
if not is_exit_node:
|
|
252
251
|
# Update 'filters' or 'units' attributes for entry node Conv2D/Conv2DTranspose or Dense layers.
|
|
253
|
-
if node.
|
|
252
|
+
if node.is_match_type(keras.layers.Conv2D) or node.is_match_type(keras.layers.Conv2DTranspose):
|
|
254
253
|
node.framework_attr[FILTERS] = int(np.sum(mask))
|
|
255
|
-
elif node.
|
|
254
|
+
elif node.is_match_type(keras.layers.Dense):
|
|
256
255
|
node.framework_attr[UNITS] = int(np.sum(mask))
|
|
257
256
|
|
|
258
257
|
if is_exit_node:
|
|
@@ -43,7 +43,7 @@ def is_node_an_input_layer(node: BaseNode) -> bool:
|
|
|
43
43
|
Whether the node represents an input layer or not.
|
|
44
44
|
"""
|
|
45
45
|
if isinstance(node, BaseNode):
|
|
46
|
-
return node.
|
|
46
|
+
return node.is_match_type(InputLayer)
|
|
47
47
|
elif isinstance(node, KerasNode):
|
|
48
48
|
return isinstance(node.layer, InputLayer)
|
|
49
49
|
else:
|
|
@@ -60,7 +60,7 @@ def is_node_a_model(node: BaseNode) -> bool:
|
|
|
60
60
|
Whether the node represents a Keras model or not.
|
|
61
61
|
"""
|
|
62
62
|
if isinstance(node, BaseNode):
|
|
63
|
-
return node.
|
|
63
|
+
return node.is_match_type(Functional) or node.is_match_type(Sequential)
|
|
64
64
|
elif isinstance(node, KerasNode):
|
|
65
65
|
return isinstance(node.layer, Functional) or isinstance(node.layer, Sequential)
|
|
66
66
|
else:
|
|
@@ -41,7 +41,7 @@ layers = keras.layers
|
|
|
41
41
|
|
|
42
42
|
REUSED_IDENTIFIER = '_reused_'
|
|
43
43
|
|
|
44
|
-
is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray))
|
|
44
|
+
is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, float))
|
|
45
45
|
is_tensor = lambda x: isinstance(x, KerasTensor)
|
|
46
46
|
|
|
47
47
|
|
|
@@ -61,18 +61,36 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
|
|
|
61
61
|
"""
|
|
62
62
|
Positional weights are saved according to their index in the node's call arguments, so
|
|
63
63
|
need to know the function arguments' names in case the weights are in the kwargs.
|
|
64
|
+
|
|
65
|
+
Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so
|
|
66
|
+
it will only include the arguments that may contain constants. For example, we don't
|
|
67
|
+
want the transpose_a attribute of tf.matmul to be saved as a constant.
|
|
68
|
+
|
|
69
|
+
Every operation we add support to, needs to be added here.
|
|
70
|
+
|
|
64
71
|
Args:
|
|
65
72
|
tfoplambda_layer: TFOpLambda layer.
|
|
66
73
|
|
|
67
74
|
Returns:
|
|
68
75
|
A dictionary mapping each argument name to its positional index: {arg_name: arg_index}.
|
|
69
76
|
"""
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
77
|
+
kwargs2index = {tf.add: {'x': 0, 'y': 1},
|
|
78
|
+
tf.subtract: {'x': 0, 'y': 1},
|
|
79
|
+
tf.divide: {'x': 0, 'y': 1},
|
|
80
|
+
tf.truediv: {'x': 0, 'y': 1},
|
|
81
|
+
tf.multiply: {'x': 0, 'y': 1},
|
|
82
|
+
tf.pow: {'x': 0, 'y': 1},
|
|
83
|
+
tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function)
|
|
84
|
+
if not kwargs2index:
|
|
85
|
+
# In TF 2.15 the function attribute is different and doesn't match the original
|
|
86
|
+
# operation object we use. Therefore, we extract kwargs2index with the symbol.
|
|
87
|
+
kwargs2index = {'__operators__.add': {'x': 0, 'y': 1},
|
|
88
|
+
'math.add': {'x': 0, 'y': 1},
|
|
89
|
+
'math.multiply': {'x': 0, 'y': 1},
|
|
90
|
+
'linalg.matmul': {'a': 0, 'b': 1},
|
|
91
|
+
'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {})
|
|
92
|
+
|
|
93
|
+
return kwargs2index
|
|
76
94
|
|
|
77
95
|
|
|
78
96
|
def build_node(node: KerasNode,
|
|
@@ -154,8 +172,9 @@ def build_node(node: KerasNode,
|
|
|
154
172
|
if is_const(v) or (keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow,
|
|
155
173
|
tf.matmul] and
|
|
156
174
|
isinstance(v, (tuple, list))):
|
|
157
|
-
|
|
158
|
-
|
|
175
|
+
if k in kwarg2index:
|
|
176
|
+
weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
|
|
177
|
+
weight_keys.append(k)
|
|
159
178
|
# remove weights and KerasTensors from op_call_kwargs
|
|
160
179
|
op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
|
|
161
180
|
if not (kwarg2index.get(k) in weights or is_tensor(v))}
|
|
@@ -40,7 +40,7 @@ def to_tf_tensor(tensor):
|
|
|
40
40
|
Logger.critical(f'Unsupported type for conversion to TF tensor: {type(tensor)}.')
|
|
41
41
|
|
|
42
42
|
|
|
43
|
-
def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor],
|
|
43
|
+
def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor, float],
|
|
44
44
|
is_single_tensor=False) -> np.ndarray:
|
|
45
45
|
"""
|
|
46
46
|
Convert a TF tensor to a Numpy array.
|
|
@@ -65,6 +65,9 @@ def tf_tensor_to_numpy(tensor: Union[List, Tuple, np.ndarray, tf.Tensor],
|
|
|
65
65
|
else:
|
|
66
66
|
return (tf_tensor_to_numpy(t) for t in tensor)
|
|
67
67
|
elif isinstance(tensor, tf.Tensor):
|
|
68
|
-
|
|
68
|
+
np_tensor = tensor.numpy()
|
|
69
|
+
return np.array([np_tensor]) if np.isscalar(np_tensor) else np_tensor
|
|
70
|
+
elif isinstance(tensor, float):
|
|
71
|
+
return np.array([tensor])
|
|
69
72
|
else:
|
|
70
73
|
Logger.critical(f'Unsupported type for conversion to Numpy array: {type(tensor)}.')
|
|
@@ -33,26 +33,31 @@ from model_compression_toolkit.core.pytorch.pytorch_device_config import get_wor
|
|
|
33
33
|
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
|
|
34
34
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
35
35
|
from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
|
|
36
|
+
from mct_quantizers import PytorchQuantizationWrapper
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
def _build_input_tensors_list(node: BaseNode,
|
|
39
40
|
graph: Graph,
|
|
40
41
|
inputs: Tuple[Any],
|
|
41
|
-
node_to_output_tensors_dict: Dict[BaseNode, List]
|
|
42
|
+
node_to_output_tensors_dict: Dict[BaseNode, List],
|
|
43
|
+
is_op_quantize_wrapper: bool) -> List[List]:
|
|
42
44
|
"""
|
|
43
|
-
Given a node, build a list of input tensors the node gets. The list is built
|
|
44
|
-
|
|
45
|
+
Given a node, build a list of input tensors the node gets. The list is built based on the
|
|
46
|
+
node's incoming edges, previous nodes' output tensors and the node's positional weights.
|
|
47
|
+
Positional weights aren't used if the node's op is PytorchQuantizationWrapper, since its
|
|
48
|
+
positional weights are already in the wrapper.
|
|
45
49
|
|
|
46
50
|
Args:
|
|
47
51
|
node: Node to build its input tensors list.
|
|
48
52
|
graph: Graph the node is in.
|
|
49
|
-
inputs: list of input tensors to model
|
|
53
|
+
inputs: list of input tensors to model.
|
|
50
54
|
node_to_output_tensors_dict: A dictionary from a node to its output tensors.
|
|
55
|
+
is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.
|
|
51
56
|
|
|
52
57
|
Returns:
|
|
53
58
|
A list of the node's input tensors.
|
|
54
59
|
"""
|
|
55
|
-
if node.
|
|
60
|
+
if node.is_match_type(DummyPlaceHolder):
|
|
56
61
|
input_tensors = [inputs[graph.get_inputs().index(node)]]
|
|
57
62
|
else:
|
|
58
63
|
input_tensors = []
|
|
@@ -62,7 +67,8 @@ def _build_input_tensors_list(node: BaseNode,
|
|
|
62
67
|
_input_tensors = node_to_output_tensors_dict[ie.source_node]
|
|
63
68
|
input_tensors.append(_input_tensors)
|
|
64
69
|
input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
|
|
65
|
-
|
|
70
|
+
if not is_op_quantize_wrapper:
|
|
71
|
+
input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
|
|
66
72
|
# convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
|
|
67
73
|
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
|
|
68
74
|
input_tensors = [to_torch_tensor(t) if isinstance(t, np.ndarray) else t
|
|
@@ -70,22 +76,27 @@ def _build_input_tensors_list(node: BaseNode,
|
|
|
70
76
|
return input_tensors
|
|
71
77
|
|
|
72
78
|
|
|
73
|
-
def _merge_inputs(_node, input_tensors: List, op_call_args: List
|
|
79
|
+
def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
|
|
80
|
+
is_op_quantize_wrapper: bool) -> List:
|
|
74
81
|
"""
|
|
75
|
-
Merge input tensors list with op_call_args, according to correct order
|
|
82
|
+
Merge input tensors list with op_call_args, according to correct order.
|
|
76
83
|
|
|
77
84
|
Args:
|
|
78
|
-
_node: The node the inputs are for
|
|
85
|
+
_node: The node the inputs are for.
|
|
79
86
|
input_tensors: activation input tensors to node.
|
|
80
|
-
op_call_args: framework node call args
|
|
87
|
+
op_call_args: framework node call args.
|
|
88
|
+
is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.
|
|
81
89
|
Returns:
|
|
82
|
-
Combined list of input_tensors and op_call_args
|
|
90
|
+
Combined list of input_tensors and op_call_args.
|
|
83
91
|
"""
|
|
84
92
|
if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
|
|
85
|
-
assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices'
|
|
86
93
|
_input_list = op_call_args.copy()
|
|
87
|
-
|
|
88
|
-
_input_list
|
|
94
|
+
if is_op_quantize_wrapper:
|
|
95
|
+
_input_list = input_tensors + _input_list
|
|
96
|
+
else:
|
|
97
|
+
assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices'
|
|
98
|
+
for i, t in zip(_node.tensor_input_indices, input_tensors):
|
|
99
|
+
_input_list.insert(i, t)
|
|
89
100
|
else:
|
|
90
101
|
_input_list = input_tensors + op_call_args
|
|
91
102
|
|
|
@@ -118,7 +129,8 @@ def _run_operation(n: BaseNode,
|
|
|
118
129
|
if isinstance(n, FunctionalNode) and n.inputs_as_list:
|
|
119
130
|
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
|
|
120
131
|
else:
|
|
121
|
-
|
|
132
|
+
merged_inputs = _merge_inputs(n, input_tensors, op_call_args, isinstance(op_func, PytorchQuantizationWrapper))
|
|
133
|
+
out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
|
|
122
134
|
|
|
123
135
|
# Add a fake quant node if the node has an activation threshold.
|
|
124
136
|
out_tensors_of_n = out_tensors_of_n_float
|
|
@@ -279,12 +291,12 @@ class PytorchModel(torch.nn.Module):
|
|
|
279
291
|
node_to_output_tensors_dict_float = dict()
|
|
280
292
|
configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO)
|
|
281
293
|
for node in self.node_sort:
|
|
294
|
+
op_func = self._get_op_func(node, configurable_nodes)
|
|
282
295
|
input_tensors = _build_input_tensors_list(node,
|
|
283
296
|
self.graph,
|
|
284
297
|
args,
|
|
285
|
-
node_to_output_tensors_dict
|
|
286
|
-
|
|
287
|
-
op_func = self._get_op_func(node, configurable_nodes)
|
|
298
|
+
node_to_output_tensors_dict,
|
|
299
|
+
isinstance(op_func, PytorchQuantizationWrapper))
|
|
288
300
|
use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node)
|
|
289
301
|
|
|
290
302
|
# Run node operation and fetch outputs
|
|
@@ -326,15 +338,16 @@ class PytorchModel(torch.nn.Module):
|
|
|
326
338
|
"""
|
|
327
339
|
return getattr(self, node.name)
|
|
328
340
|
|
|
329
|
-
def _get_activation_quantization_fn(self, node) -> Tuple[bool,
|
|
341
|
+
def _get_activation_quantization_fn(self, node) -> Tuple[bool, Callable]:
|
|
330
342
|
"""
|
|
331
343
|
Get activation quantization parameters for this node.
|
|
332
344
|
|
|
333
345
|
Args:
|
|
334
346
|
node: Node from which to extract the activation quantization parameters.
|
|
335
347
|
|
|
336
|
-
Returns:
|
|
337
|
-
|
|
348
|
+
Returns:
|
|
349
|
+
Flag to indicate if we quantize activations using a quantization holder and a quantization
|
|
350
|
+
function to use for the node's activations.
|
|
338
351
|
"""
|
|
339
352
|
activation_quantization_holder = self.node_to_activation_quantization_holder.get(node.name)
|
|
340
353
|
use_activation_quantization = node.is_activation_quantization_enabled()
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py
CHANGED
|
@@ -62,11 +62,11 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode,
|
|
|
62
62
|
Returns:
|
|
63
63
|
The modified convolution node's weight/kernel.
|
|
64
64
|
"""
|
|
65
|
-
if conv_node.
|
|
65
|
+
if conv_node.is_match_type(ConvTranspose2d):
|
|
66
66
|
_scale = weights_scale[None, :, None, None]
|
|
67
67
|
else:
|
|
68
68
|
_scale = weights_scale[:, None, None, None]
|
|
69
|
-
if conv_node.
|
|
69
|
+
if conv_node.is_match_type(ConvTranspose2d) and conv_node.framework_attr[GROUPS] > 1:
|
|
70
70
|
# PyTorch ConvTranspose2d kernel with groups stacks groups on in_channels axis, so need to reshape the kernel
|
|
71
71
|
# so the groups are stacked on the out_channels axis to match the scale vector (then reshape back to original
|
|
72
72
|
# shape)
|
|
@@ -93,10 +93,10 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
|
|
|
93
93
|
Returns:
|
|
94
94
|
The modified convolution node's weight/kernel.
|
|
95
95
|
"""
|
|
96
|
-
if conv_node.
|
|
96
|
+
if conv_node.is_match_type(Conv2d) and conv_node.framework_attr[GROUPS] > 1:
|
|
97
97
|
bias_update = (kernel * bias_factor[:, None, None, None]).flatten()
|
|
98
98
|
_scale = weights_scale[:, None, None, None]
|
|
99
|
-
elif conv_node.
|
|
99
|
+
elif conv_node.is_match_type(ConvTranspose2d):
|
|
100
100
|
bias_update = (kernel * bias_factor[:, None, None, None]).sum(axis=0).flatten()
|
|
101
101
|
_scale = weights_scale[:, None, None, None]
|
|
102
102
|
else:
|
|
@@ -125,8 +125,8 @@ def is_group_conv_fn(node: BaseNode) -> bool:
|
|
|
125
125
|
Returns:
|
|
126
126
|
True if the node is a group convolution, else False
|
|
127
127
|
"""
|
|
128
|
-
return node.
|
|
129
|
-
|
|
128
|
+
return (node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d)) and \
|
|
129
|
+
node.framework_attr[GROUPS] not in [node.framework_attr[IN_CHANNELS], 1]
|
|
130
130
|
|
|
131
131
|
|
|
132
132
|
def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
|
|
@@ -140,8 +140,8 @@ def get_foldable_node_type_and_validity_fn(node: BaseNode) -> [bool, bool]:
|
|
|
140
140
|
is_bn: True if the node is a batch norm, else False
|
|
141
141
|
is_dw_valid: True if the node is a dw-convolution valid for folding or a batch-norm node, else False
|
|
142
142
|
"""
|
|
143
|
-
is_bn = node.
|
|
144
|
-
is_dw = node.
|
|
143
|
+
is_bn = node.is_match_type(BatchNorm2d)
|
|
144
|
+
is_dw = node.is_match_type(Conv2d) and node.framework_attr[GROUPS] == node.framework_attr[IN_CHANNELS]
|
|
145
145
|
is_dw_valid = is_dw and np.all(np.array(node.get_weights_by_keys(KERNEL).shape[2:]) == 1)
|
|
146
146
|
return is_bn, is_dw_valid
|
|
147
147
|
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py
CHANGED
|
@@ -48,9 +48,9 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
|
|
|
48
48
|
Graph after applying the substitution.
|
|
49
49
|
"""
|
|
50
50
|
# Set new layer
|
|
51
|
-
if func_node.
|
|
51
|
+
if func_node.is_match_type(conv2d):
|
|
52
52
|
new_layer = Conv2d
|
|
53
|
-
elif func_node.
|
|
53
|
+
elif func_node.is_match_type(conv_transpose2d):
|
|
54
54
|
new_layer = ConvTranspose2d
|
|
55
55
|
else:
|
|
56
56
|
Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.') # pragma: no cover
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py
CHANGED
|
@@ -53,7 +53,7 @@ def conv2d_collapsing_fn(first_node: BaseNode,
|
|
|
53
53
|
Returns:
|
|
54
54
|
The modified layer node's weights: kernel, bias
|
|
55
55
|
"""
|
|
56
|
-
if first_node.
|
|
56
|
+
if first_node.is_match_type(Conv2d) and second_node.is_match_type(Conv2d):
|
|
57
57
|
# Get nodes attributes
|
|
58
58
|
kernel1 = first_node.get_weights_by_keys(kernel_str)
|
|
59
59
|
kernel2 = second_node.get_weights_by_keys(kernel_str)
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py
CHANGED
|
@@ -76,17 +76,17 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
|
|
|
76
76
|
second_op2d_node = nodes_list[2]
|
|
77
77
|
|
|
78
78
|
# only act on bound relu with not POT max value and 0 min value
|
|
79
|
-
if non_linear_node.
|
|
79
|
+
if non_linear_node.is_match_type(ReLU6):
|
|
80
80
|
scale_factor = 6.0 / self.threshold
|
|
81
81
|
non_linear_node.layer_class = Hardtanh
|
|
82
82
|
non_linear_node.framework_attr[INPLACE] = False
|
|
83
83
|
non_linear_node.framework_attr[HARDTANH_MIN_VAL] = 0.0
|
|
84
84
|
non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold
|
|
85
|
-
elif non_linear_node.
|
|
85
|
+
elif non_linear_node.is_match_type(relu6):
|
|
86
86
|
scale_factor = 6.0 / self.threshold
|
|
87
87
|
non_linear_node.functional_op = hardtanh
|
|
88
88
|
non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, False)
|
|
89
|
-
elif non_linear_node.
|
|
89
|
+
elif non_linear_node.is_match_type(Hardtanh):
|
|
90
90
|
if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
|
|
91
91
|
(np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
|
|
92
92
|
np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
|
|
@@ -94,7 +94,7 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
|
|
|
94
94
|
non_linear_node.framework_attr[HARDTANH_MAX_VAL] = self.threshold
|
|
95
95
|
else:
|
|
96
96
|
return graph
|
|
97
|
-
elif non_linear_node.
|
|
97
|
+
elif non_linear_node.is_match_type(hardtanh):
|
|
98
98
|
if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
|
|
99
99
|
(np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
|
|
100
100
|
np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py
CHANGED
|
@@ -46,7 +46,7 @@ def residual_collapsing_fn(first_node: BaseNode,
|
|
|
46
46
|
Returns:
|
|
47
47
|
The modified layer node's weights: kernel
|
|
48
48
|
"""
|
|
49
|
-
if first_node.
|
|
49
|
+
if first_node.is_match_type(Conv2d):
|
|
50
50
|
# Get nodes attributes
|
|
51
51
|
kernel = first_node.get_weights_by_keys(kernel_str)
|
|
52
52
|
(Cout, Cin, kH, kW) = kernel.shape
|
|
@@ -76,9 +76,9 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
|
76
76
|
pruned_parameters = {}
|
|
77
77
|
mask_bool = output_mask.astype(bool)
|
|
78
78
|
node.weights = pruned_parameters
|
|
79
|
-
if node.
|
|
79
|
+
if node.is_match_type(torch.nn.BatchNorm2d):
|
|
80
80
|
node.framework_attr[NUM_FEATURES] = int(np.sum(input_mask))
|
|
81
|
-
elif node.
|
|
81
|
+
elif node.is_match_type(torch.nn.PReLU):
|
|
82
82
|
if node.framework_attr[NUM_PARAMETERS] > 1:
|
|
83
83
|
node.framework_attr[NUM_PARAMETERS] = int(np.sum(input_mask))
|
|
84
84
|
else:
|
|
@@ -227,9 +227,9 @@ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
|
|
|
227
227
|
"""
|
|
228
228
|
|
|
229
229
|
# Check if the node is a Conv2d or ConvTranspose2d layer with groups set to 1.
|
|
230
|
-
if node.
|
|
230
|
+
if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d):
|
|
231
231
|
return node.framework_attr[GROUPS] == 1
|
|
232
|
-
return node.
|
|
232
|
+
return node.is_match_type(torch.nn.Linear)
|
|
233
233
|
|
|
234
234
|
|
|
235
235
|
def _prune_pytorch_edge_node(node: BaseNode,
|
|
@@ -268,18 +268,18 @@ def _prune_pytorch_edge_node(node: BaseNode,
|
|
|
268
268
|
if not is_exit_node:
|
|
269
269
|
# Update 'out_channels' or 'out_features' attributes for entry nodes
|
|
270
270
|
# Conv2d,ConvTranspose2d / Linear layers.
|
|
271
|
-
if node.
|
|
271
|
+
if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d):
|
|
272
272
|
node.framework_attr[OUT_CHANNELS] = int(np.sum(mask))
|
|
273
|
-
elif node.
|
|
273
|
+
elif node.is_match_type(torch.nn.Linear):
|
|
274
274
|
node.framework_attr[OUT_FEATURES] = int(np.sum(mask))
|
|
275
275
|
else:
|
|
276
276
|
Logger.critical(f"{node.type} is currently not supported"
|
|
277
277
|
f"as an edge node in a pruning section")
|
|
278
278
|
|
|
279
279
|
if is_exit_node:
|
|
280
|
-
if node.
|
|
280
|
+
if node.is_match_type(torch.nn.Conv2d) or node.is_match_type(torch.nn.ConvTranspose2d):
|
|
281
281
|
node.framework_attr[IN_CHANNELS] = int(np.sum(mask))
|
|
282
|
-
elif node.
|
|
282
|
+
elif node.is_match_type(torch.nn.Linear):
|
|
283
283
|
node.framework_attr[IN_FEATURES] = int(np.sum(mask))
|
|
284
284
|
else:
|
|
285
285
|
Logger.critical(f"{node.type} is currently not supported"
|
|
@@ -398,8 +398,8 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
398
398
|
Returns: True if the node should be considered an interest point, False otherwise.
|
|
399
399
|
"""
|
|
400
400
|
|
|
401
|
-
if node.
|
|
402
|
-
|
|
401
|
+
if any([node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax,
|
|
402
|
+
softmax, operator.add, add, cat, operator.concat]]):
|
|
403
403
|
return True
|
|
404
404
|
return False
|
|
405
405
|
|
|
@@ -464,12 +464,12 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
464
464
|
kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
|
|
465
465
|
output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
|
|
466
466
|
|
|
467
|
-
if node.
|
|
467
|
+
if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
|
|
468
468
|
# (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
|
|
469
469
|
return np.prod([x for x in output_shape if x is not None]) * \
|
|
470
470
|
kernel_shape[input_channel_axis] * \
|
|
471
471
|
(kernel_shape[0] * kernel_shape[1])
|
|
472
|
-
elif node.
|
|
472
|
+
elif node.is_match_type(Linear):
|
|
473
473
|
# IN * OUT
|
|
474
474
|
return kernel_shape[0] * kernel_shape[1]
|
|
475
475
|
else:
|
|
@@ -552,7 +552,6 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
552
552
|
Returns:
|
|
553
553
|
weight_quantizers: A dictionary between a weight's name to its quantizer.
|
|
554
554
|
activation_quantizers: A list of activations quantization, one for each layer output.
|
|
555
|
-
|
|
556
555
|
"""
|
|
557
556
|
|
|
558
557
|
return get_inferable_quantizers(node,
|
|
@@ -62,7 +62,7 @@ def _get_mean_std_outputs(node: BaseNode,
|
|
|
62
62
|
"""
|
|
63
63
|
mean_output, std_output = None, None
|
|
64
64
|
|
|
65
|
-
if node.
|
|
65
|
+
if node.is_match_type(BatchNorm2d):
|
|
66
66
|
mean_output = node.get_weights_by_keys(BETA)
|
|
67
67
|
if node.get_weights_by_keys(GAMMA) is None:
|
|
68
68
|
std_output = 1.0
|
|
@@ -72,7 +72,7 @@ def _get_mean_std_outputs(node: BaseNode,
|
|
|
72
72
|
mean_output = 0.0
|
|
73
73
|
else:
|
|
74
74
|
next_node_list = graph.get_next_nodes(node)
|
|
75
|
-
bn_nodes = [bn_node for bn_node in next_node_list if bn_node.
|
|
75
|
+
bn_nodes = [bn_node for bn_node in next_node_list if bn_node.is_match_type(BatchNorm2d)]
|
|
76
76
|
if len(bn_nodes) != 0:
|
|
77
77
|
bn_node = bn_nodes[0]
|
|
78
78
|
moving_variance = bn_node.get_weights_by_keys(MOVING_VARIANCE)
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -42,8 +42,12 @@ if FOUND_TF:
|
|
|
42
42
|
"""
|
|
43
43
|
weights_quantizers, _ = fw_impl.get_inferable_quantizers(node)
|
|
44
44
|
if len(weights_quantizers) > 0:
|
|
45
|
+
# for positional weights we need to extract the weight's value.
|
|
46
|
+
weights_values = {attr: node.get_weights_by_keys(attr)
|
|
47
|
+
for attr in weights_quantizers if isinstance(attr, int)}
|
|
45
48
|
return KerasQuantizationWrapper(layer,
|
|
46
|
-
weights_quantizers
|
|
49
|
+
weights_quantizers,
|
|
50
|
+
weights_values)
|
|
47
51
|
return layer
|
|
48
52
|
|
|
49
53
|
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -29,7 +29,7 @@ if FOUND_TORCH:
|
|
|
29
29
|
|
|
30
30
|
def fully_quantized_wrapper(node: common.BaseNode,
|
|
31
31
|
module: torch.nn.Module,
|
|
32
|
-
fw_impl) -> Union[torch.nn.Module,PytorchQuantizationWrapper]:
|
|
32
|
+
fw_impl) -> Union[torch.nn.Module, PytorchQuantizationWrapper]:
|
|
33
33
|
"""
|
|
34
34
|
A function which takes a computational graph node and a pytorch module and
|
|
35
35
|
perform the quantization wrapping
|
|
@@ -37,20 +37,26 @@ if FOUND_TORCH:
|
|
|
37
37
|
Args:
|
|
38
38
|
node: A node of mct graph.
|
|
39
39
|
module: A Pytorch module
|
|
40
|
+
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
40
41
|
Returns: Wrapped layer
|
|
41
42
|
|
|
42
43
|
"""
|
|
43
44
|
weight_quantizers, _ = fw_impl.get_inferable_quantizers(node)
|
|
44
45
|
if len(weight_quantizers) > 0:
|
|
45
|
-
|
|
46
|
+
# for positional weights we need to extract the weight's value.
|
|
47
|
+
weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
|
|
48
|
+
for attr in weight_quantizers if isinstance(attr, int)}
|
|
49
|
+
return PytorchQuantizationWrapper(module, weight_quantizers, weights_values)
|
|
46
50
|
return module
|
|
47
51
|
|
|
52
|
+
|
|
48
53
|
def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable:
|
|
49
54
|
"""
|
|
50
55
|
Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
|
|
51
56
|
If the layer is not supposed to be wrapped with an activation quantizer - return None.
|
|
52
57
|
Args:
|
|
53
58
|
node: Node to attach a PytorchActivationQuantizationHolder to its output.
|
|
59
|
+
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
54
60
|
Returns:
|
|
55
61
|
A PytorchActivationQuantizationHolder module for the node's activation quantization.
|
|
56
62
|
"""
|
|
@@ -64,6 +70,7 @@ if FOUND_TORCH:
|
|
|
64
70
|
f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
|
|
65
71
|
f'were found for node {node}')
|
|
66
72
|
|
|
73
|
+
|
|
67
74
|
def get_exportable_pytorch_model(graph: Graph):
|
|
68
75
|
"""
|
|
69
76
|
Convert graph to fully quantized PyTorch model.
|