mct-nightly 2.1.0.20240622.419__py3-none-any.whl → 2.1.0.20240624.520__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/RECORD +14 -14
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/functional_node.py +3 -4
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +10 -12
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +11 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py +21 -15
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py +19 -17
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +193 -116
- model_compression_toolkit/core/runner.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -1
- {mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=9NcQg8E0BkwMF32oeqvh_b8tuvTkx4OAmKJ_1q74DmE,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -7,7 +7,7 @@ model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn
|
|
7
7
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
8
8
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=kM70wmNG3yMFiGQc0uO0wn9j4ZbSWxUEykpxDK55doc,10567
|
9
9
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=0ga95vh_ZXO79r8FB26L5GIZKHkG98wq1hMsNH1bIeU,6453
|
10
|
-
model_compression_toolkit/core/runner.py,sha256=
|
10
|
+
model_compression_toolkit/core/runner.py,sha256=4TtOgyNb4cXr52dOlDqYxLm3rnLR6uHPDNoZiEFL9XA,12655
|
11
11
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
12
12
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
13
13
|
model_compression_toolkit/core/common/framework_implementation.py,sha256=8b6M1GcUR9bDgoxwqyNP8C6KSU9OTQ5hIk20Y74eLPo,20896
|
@@ -33,7 +33,7 @@ model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaU
|
|
33
33
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
|
34
34
|
model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
|
35
35
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
36
|
-
model_compression_toolkit/core/common/graph/functional_node.py,sha256=
|
36
|
+
model_compression_toolkit/core/common/graph/functional_node.py,sha256=BbxQ-WRk4R-5hbpQDBANkhRRTkaG7eogeiJwLfLb_EU,3950
|
37
37
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
38
38
|
model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
|
39
39
|
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
|
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
|
|
222
222
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
223
223
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
224
224
|
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
|
225
|
-
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=
|
225
|
+
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=b3RJ9XpbN2XXlCXEVjxLg3NenmtFfnp_UBRKDIEka8A,18698
|
226
226
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
227
227
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
228
228
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
|
@@ -233,9 +233,9 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchno
|
|
233
233
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
|
234
234
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
|
235
235
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
|
236
|
-
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=
|
237
|
-
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=
|
238
|
-
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=
|
236
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=iTuP1hjuTZTGcE7izfs_UOWBGeEBFRvRIU4QCh-b21M,4627
|
237
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=7GZY7lU3LUUaO5iiccHkUP62PB0QeGAGOZdUSGMkFBY,4450
|
238
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=XhiLVcnCc_gF-6mjxbf9C4bYg5YL_GCvDJmcdLkBNAg,4151
|
239
239
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py,sha256=CXSMASpc_Zed3BJ2CsER69zKxE6ncFvvKQWDO1JxKYI,5849
|
240
240
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=VNg-VgzCxSyqy2J3neEPl6U0SPO8UIVU_T47bGhz4FE,38459
|
241
241
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=q1a3HieQtaOmWG2WGXp6GHYAvxa3CZ9dJUx9dqMAsS8,5695
|
@@ -261,7 +261,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
|
|
261
261
|
model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
|
262
262
|
model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
|
263
263
|
model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
264
|
-
model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=
|
264
|
+
model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=ESL8k7RLZogTyG_oTTFDmm4RauZvx2gU-UvnOnEsH6Q,15948
|
265
265
|
model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=TaolORuwBZEddWe-q0Mg79Nmswz-Sq3-9-4o8UxFQ50,1028
|
266
266
|
model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
|
267
267
|
model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
@@ -425,7 +425,7 @@ model_compression_toolkit/target_platform_capabilities/target_platform/targetpla
|
|
425
425
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
|
426
426
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py,sha256=fIheShGOnxWYKqT8saHpBJqOU5RG_1Hp9qHry7IviIw,2115
|
427
427
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py,sha256=Cl6-mACpje2jM8RJkibbqE3hvTkFR3r26-lW021mIiA,4019
|
428
|
-
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py,sha256=
|
428
|
+
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py,sha256=iZDgHd0SVbgNTT-jtSP0SWsaRGfAJM_p-wpBlBkpRAQ,6723
|
429
429
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py,sha256=KP8IWlHzkXzVjqIiRtAW6sTYyHJ2wVFFX4hMt_N6o3s,9910
|
430
430
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities_component.py,sha256=FvrYI0Qy7DCmDp2gyUYyCZq5pY84JgLtJqSIiVTJ8Ss,1030
|
431
431
|
model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -517,8 +517,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
517
517
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
|
518
518
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
519
519
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
|
520
|
-
mct_nightly-2.1.0.
|
521
|
-
mct_nightly-2.1.0.
|
522
|
-
mct_nightly-2.1.0.
|
523
|
-
mct_nightly-2.1.0.
|
524
|
-
mct_nightly-2.1.0.
|
520
|
+
mct_nightly-2.1.0.20240624.520.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
521
|
+
mct_nightly-2.1.0.20240624.520.dist-info/METADATA,sha256=0Lh6S3Ea0DK-D1dmGnRH-IwyzULmoho7PC7LXgUL5x0,19726
|
522
|
+
mct_nightly-2.1.0.20240624.520.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
523
|
+
mct_nightly-2.1.0.20240624.520.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
524
|
+
mct_nightly-2.1.0.20240624.520.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.1.0.
|
30
|
+
__version__ = "2.1.0.20240624.000520"
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Dict, Any, Tuple, Type
|
1
|
+
from typing import Dict, Any, Tuple, Type, List, Union
|
2
2
|
|
3
3
|
from model_compression_toolkit.constants import FOUND_TF
|
4
4
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
@@ -25,7 +25,7 @@ class FunctionalNode(BaseNode):
|
|
25
25
|
functional_op: Any = None,
|
26
26
|
inputs_as_list: bool = False,
|
27
27
|
has_activation: bool = True,
|
28
|
-
tensor_input_allocs = None):
|
28
|
+
tensor_input_allocs: List[Union[int, str]] = None):
|
29
29
|
"""
|
30
30
|
Init a FunctionalNode object.
|
31
31
|
|
@@ -44,8 +44,7 @@ class FunctionalNode(BaseNode):
|
|
44
44
|
functional_op: The op the node implements.
|
45
45
|
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
|
46
46
|
has_activation: Whether the node has activations that we might want to quantize.
|
47
|
-
tensor_input_allocs: A list of indices for
|
48
|
-
|
47
|
+
tensor_input_allocs: A list of indices and strings marking where the input tensors are allocated in the node's args and kwargs.
|
49
48
|
"""
|
50
49
|
|
51
50
|
super().__init__(name,
|
@@ -106,7 +106,7 @@ def _run_operation(n: BaseNode,
|
|
106
106
|
input_tensors: List,
|
107
107
|
op_func: Any,
|
108
108
|
quantize_node_activation_fn,
|
109
|
-
use_activation_quantization: bool) -> Tuple[
|
109
|
+
use_activation_quantization: bool) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
110
110
|
"""
|
111
111
|
Applying the layer (op_func) to the input tensors (input_tensors).
|
112
112
|
If quantized is set to True, and the layer's corresponding node (n) has quantization
|
@@ -126,17 +126,17 @@ def _run_operation(n: BaseNode,
|
|
126
126
|
op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
|
127
127
|
functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}
|
128
128
|
|
129
|
-
|
130
|
-
|
131
|
-
|
129
|
+
# Insert positional weights only when not a quantized functional node, because quantized functional nodes
|
130
|
+
# insert the quantized weights in the wrapper.
|
131
|
+
if isinstance(n, FunctionalNode) and isinstance(op_func, PytorchQuantizationWrapper):
|
132
|
+
_tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]
|
133
|
+
else:
|
132
134
|
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
|
133
135
|
# convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
|
134
136
|
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
|
135
137
|
input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
|
136
138
|
for t in input_tensors]
|
137
139
|
_tensor_input_allocs = None
|
138
|
-
else:
|
139
|
-
_tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]
|
140
140
|
|
141
141
|
if isinstance(n, FunctionalNode) and n.inputs_as_list:
|
142
142
|
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
|
@@ -152,6 +152,8 @@ def _run_operation(n: BaseNode,
|
|
152
152
|
out_tensors_of_n_float = torch.cat(out_tensors_of_n_float, dim=0)
|
153
153
|
out_tensors_of_n = quantize_node_activation_fn(out_tensors_of_n_float)
|
154
154
|
|
155
|
+
if not isinstance(out_tensors_of_n, list):
|
156
|
+
out_tensors_of_n, out_tensors_of_n_float = [out_tensors_of_n], [out_tensors_of_n_float]
|
155
157
|
return out_tensors_of_n, out_tensors_of_n_float
|
156
158
|
|
157
159
|
|
@@ -318,12 +320,8 @@ class PytorchModel(torch.nn.Module):
|
|
318
320
|
quantize_node_activation_fn=activation_quantization_fn,
|
319
321
|
use_activation_quantization=use_activation_quantization)
|
320
322
|
|
321
|
-
|
322
|
-
|
323
|
-
node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})
|
324
|
-
else:
|
325
|
-
node_to_output_tensors_dict.update({node: [out_tensors_of_n]})
|
326
|
-
node_to_output_tensors_dict_float.update({node: [out_tensors_of_n_float]})
|
323
|
+
node_to_output_tensors_dict.update({node: out_tensors_of_n})
|
324
|
+
node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})
|
327
325
|
|
328
326
|
if self.append2output:
|
329
327
|
outputs = _generate_outputs(self.append2output,
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py
CHANGED
@@ -19,6 +19,7 @@ from model_compression_toolkit.logger import Logger
|
|
19
19
|
from model_compression_toolkit.core import common
|
20
20
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
21
21
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
22
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
22
23
|
from model_compression_toolkit.core.pytorch.constants import IN_CHANNELS, OUT_CHANNELS, KERNEL_SIZE, KERNEL, BIAS
|
23
24
|
from model_compression_toolkit.core.common import FrameworkInfo
|
24
25
|
|
@@ -37,7 +38,7 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
|
|
37
38
|
|
38
39
|
def substitute(self,
|
39
40
|
graph: Graph,
|
40
|
-
func_node:
|
41
|
+
func_node: FunctionalNode) -> Graph:
|
41
42
|
"""
|
42
43
|
Substitute functional and conv/linear layer with torch layer
|
43
44
|
Args:
|
@@ -60,9 +61,15 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
|
|
60
61
|
# Create new node of layer convolution
|
61
62
|
if 1 not in func_node.weights:
|
62
63
|
Logger.critical(f'Weight input missing for node {func_node.name}.') # pragma: no cover
|
63
|
-
|
64
|
-
|
65
|
-
|
64
|
+
# Extract index of kernel and bias according to tensor_input_allocs if they were input as kwargs. If
|
65
|
+
# they were input as args, use their fixed positions.
|
66
|
+
weight_index = func_node.tensor_input_allocs.index(KERNEL) if KERNEL in func_node.tensor_input_allocs else 1
|
67
|
+
bias_index = func_node.tensor_input_allocs.index(BIAS) if BIAS in func_node.tensor_input_allocs else 2
|
68
|
+
if weight_index not in func_node.weights:
|
69
|
+
Logger.critical(f'Mismatch between tensor_input_allocs and weight index in node {func_node.name}.') # pragma: no cover
|
70
|
+
weight = func_node.weights[weight_index]
|
71
|
+
bias = func_node.weights.get(bias_index)
|
72
|
+
framework_attr = func_node.op_call_kwargs
|
66
73
|
framework_attr.update({OUT_CHANNELS: weight.shape[out_channel_index]})
|
67
74
|
framework_attr.update({IN_CHANNELS: weight.shape[in_channel_index]})
|
68
75
|
framework_attr.update({KERNEL_SIZE: weight.shape[2:]})
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py
CHANGED
@@ -20,6 +20,7 @@ import torch.nn.functional as F
|
|
20
20
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
21
21
|
from model_compression_toolkit.core import common
|
22
22
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
23
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
23
24
|
from model_compression_toolkit.core.pytorch.constants import *
|
24
25
|
from model_compression_toolkit.logger import Logger
|
25
26
|
|
@@ -37,9 +38,12 @@ class FunctionalBatchNorm(common.BaseSubstitution):
|
|
37
38
|
super().__init__(matcher_instance=bn_node)
|
38
39
|
|
39
40
|
@staticmethod
|
40
|
-
def get_attributes_from_weights(node:
|
41
|
+
def get_attributes_from_weights(node: FunctionalNode) -> Dict:
|
41
42
|
"""
|
42
|
-
|
43
|
+
Convert functional batch_norm positional weights to BatchNorm2d weights. Extract indices of gamma
|
44
|
+
and beta according to tensor_input_allocs if they were input as kwargs. If they were input as args,
|
45
|
+
use their fixed positions.
|
46
|
+
|
43
47
|
Args:
|
44
48
|
node: functional batch_norm node.
|
45
49
|
|
@@ -53,23 +57,22 @@ class FunctionalBatchNorm(common.BaseSubstitution):
|
|
53
57
|
GAMMA: np.ones(node.weights[1].shape),
|
54
58
|
BETA: np.zeros(node.weights[1].shape)}
|
55
59
|
|
56
|
-
|
57
|
-
|
60
|
+
# Check if weight and/or bias were not given.
|
61
|
+
if KERNEL in node.tensor_input_allocs:
|
62
|
+
weights_dict[GAMMA] = node.weights[node.tensor_input_allocs.index(KERNEL)]
|
63
|
+
elif KERNEL not in node.op_call_kwargs:
|
64
|
+
weights_dict[GAMMA] = node.weights[3]
|
58
65
|
|
59
|
-
if
|
60
|
-
|
61
|
-
|
62
|
-
else:
|
63
|
-
weights_dict[BETA] = node.weights[3]
|
64
|
-
if 4 in node.weights:
|
65
|
-
assert has_bias
|
66
|
+
if BIAS in node.tensor_input_allocs:
|
67
|
+
weights_dict[BETA] = node.weights[node.tensor_input_allocs.index(BIAS)]
|
68
|
+
elif BIAS not in node.op_call_kwargs:
|
66
69
|
weights_dict[BETA] = node.weights[4]
|
67
70
|
|
68
71
|
return weights_dict
|
69
72
|
|
70
73
|
def substitute(self,
|
71
74
|
graph: Graph,
|
72
|
-
node:
|
75
|
+
node: FunctionalNode) -> Graph:
|
73
76
|
"""
|
74
77
|
Substitute functional.batch_norm and its inputs with BatchNorm2d.
|
75
78
|
Args:
|
@@ -87,10 +90,13 @@ class FunctionalBatchNorm(common.BaseSubstitution):
|
|
87
90
|
bn_node_weights = self.get_attributes_from_weights(node)
|
88
91
|
if not bn_node_weights:
|
89
92
|
return graph
|
93
|
+
framework_attr = {NUM_FEATURES: out_channels}
|
94
|
+
if EPSILON in node.op_call_kwargs:
|
95
|
+
framework_attr.update({EPSILON: node.op_call_kwargs[EPSILON]})
|
96
|
+
if MOMENTUM in node.op_call_kwargs:
|
97
|
+
framework_attr.update({MOMENTUM: node.op_call_kwargs[MOMENTUM]})
|
90
98
|
new_batchnorm2d = BaseNode(name=node.name + '_into_BatchNorm2d',
|
91
|
-
framework_attr=
|
92
|
-
EPSILON: EPSILON_VAL,
|
93
|
-
MOMENTUM: MOMENTUM_VAL},
|
99
|
+
framework_attr=framework_attr,
|
94
100
|
input_shape=node.output_shape,
|
95
101
|
output_shape=node.output_shape,
|
96
102
|
weights=bn_node_weights,
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py
CHANGED
@@ -21,6 +21,7 @@ from typing import Dict, Tuple, List
|
|
21
21
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
22
22
|
from model_compression_toolkit.core import common
|
23
23
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
24
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
24
25
|
from model_compression_toolkit.core.pytorch.constants import *
|
25
26
|
from model_compression_toolkit.logger import Logger
|
26
27
|
|
@@ -38,9 +39,11 @@ class FunctionalLayerNorm(common.BaseSubstitution):
|
|
38
39
|
super().__init__(matcher_instance=ln_node)
|
39
40
|
|
40
41
|
@staticmethod
|
41
|
-
def get_attributes_from_weights(node:
|
42
|
+
def get_attributes_from_weights(node: FunctionalNode, normalized_shape: [Tuple, List, int]) -> Dict:
|
42
43
|
"""
|
43
|
-
|
44
|
+
Convert functional layer_norm positional weights to LayerNorm weights. Extract indices of gamma
|
45
|
+
and beta according to tensor_input_allocs if they were input as kwargs. If they were input as args,
|
46
|
+
use their fixed positions.
|
44
47
|
Args:
|
45
48
|
node: Node that match the pattern in the substitution init.
|
46
49
|
normalized_shape: nn.LayerNorm "normalized_shape" argument
|
@@ -50,28 +53,26 @@ class FunctionalLayerNorm(common.BaseSubstitution):
|
|
50
53
|
"""
|
51
54
|
|
52
55
|
# Define default weight and bias
|
53
|
-
weights_dict = {GAMMA: np.ones(normalized_shape),
|
54
|
-
BETA: np.zeros(normalized_shape)
|
56
|
+
weights_dict = {GAMMA: np.ones(normalized_shape), # Default value in case weight is not given
|
57
|
+
BETA: np.zeros(normalized_shape) # Default value in case bias is not given
|
55
58
|
}
|
56
59
|
|
57
60
|
# Check if weight and/or bias were not given.
|
58
|
-
|
59
|
-
|
61
|
+
if KERNEL in node.tensor_input_allocs:
|
62
|
+
weights_dict[GAMMA] = node.weights[node.tensor_input_allocs.index(KERNEL)]
|
63
|
+
elif KERNEL not in node.op_call_kwargs:
|
64
|
+
weights_dict[GAMMA] = node.weights[1]
|
60
65
|
|
61
|
-
if
|
62
|
-
|
63
|
-
|
64
|
-
else:
|
65
|
-
weights_dict[BETA] = node.weights[1]
|
66
|
-
if 2 in node.weights:
|
67
|
-
assert has_bias
|
66
|
+
if BIAS in node.tensor_input_allocs:
|
67
|
+
weights_dict[BETA] = node.weights[node.tensor_input_allocs.index(BIAS)]
|
68
|
+
elif BIAS not in node.op_call_kwargs:
|
68
69
|
weights_dict[BETA] = node.weights[2]
|
69
70
|
|
70
71
|
return weights_dict
|
71
72
|
|
72
73
|
def substitute(self,
|
73
74
|
graph: Graph,
|
74
|
-
node:
|
75
|
+
node: FunctionalNode) -> Graph:
|
75
76
|
"""
|
76
77
|
Substitute functional.layer_norm and its inputs with LayerNorm.
|
77
78
|
Args:
|
@@ -85,10 +86,11 @@ class FunctionalLayerNorm(common.BaseSubstitution):
|
|
85
86
|
|
86
87
|
ln_node_weights = self.get_attributes_from_weights(node, normalized_shape)
|
87
88
|
|
89
|
+
framework_attr = {NORMALIZED_SHAPE: normalized_shape}
|
90
|
+
if EPSILON in node.op_call_kwargs:
|
91
|
+
framework_attr.update({EPSILON: node.op_call_kwargs[EPSILON]})
|
88
92
|
new_layernorm = BaseNode(name=node.name + '_into_LayerNorm',
|
89
|
-
framework_attr=
|
90
|
-
EPSILON: node.framework_attr.get('eps'),
|
91
|
-
},
|
93
|
+
framework_attr=framework_attr,
|
92
94
|
input_shape=node.output_shape,
|
93
95
|
output_shape=node.output_shape,
|
94
96
|
weights=ln_node_weights,
|
@@ -13,11 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import inspect
|
16
|
-
from
|
16
|
+
from operator import getitem
|
17
|
+
from typing import Dict, List, Tuple, Callable, Union, Any, Type
|
18
|
+
|
19
|
+
import numpy as np
|
17
20
|
import torch
|
18
21
|
from torch.fx import GraphModule, Node
|
19
22
|
|
20
|
-
from model_compression_toolkit.core import common
|
21
23
|
from model_compression_toolkit.core.common import BaseNode
|
22
24
|
from model_compression_toolkit.core.common.graph.base_graph import OutTensor
|
23
25
|
from model_compression_toolkit.core.common.graph.edge import Edge
|
@@ -28,29 +30,131 @@ from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlac
|
|
28
30
|
from model_compression_toolkit.logger import Logger
|
29
31
|
|
30
32
|
|
31
|
-
def
|
33
|
+
def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule],
|
34
|
+
to_numpy: Callable) -> Dict[str, np.ndarray]:
|
32
35
|
"""
|
33
|
-
Extract
|
36
|
+
Extract parameters & buffers from input module to a dictionary.
|
34
37
|
Args:
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
38
|
+
module: FX or PyTorch module to extract parameters and buffers from.
|
39
|
+
|
40
|
+
Returns:
|
41
|
+
Dictionary containing module parameters and buffers by name.
|
42
|
+
"""
|
43
|
+
|
44
|
+
named_parameters = {name: to_numpy(parameter) for name, parameter in module.named_parameters()}
|
45
|
+
named_buffers = {name: to_numpy(buffer) for name, buffer in module.named_buffers()}
|
46
|
+
|
47
|
+
return {**named_parameters, **named_buffers}
|
48
|
+
|
49
|
+
|
50
|
+
def is_instance_first_arg(n: Node, expected_type: Union[Type, Tuple[Type]]) -> bool:
|
51
|
+
"""
|
52
|
+
Check whether first argument of the node is the expected type
|
53
|
+
Args:
|
54
|
+
n: fx node.
|
55
|
+
expected_type: Expected 1st argument type.
|
56
|
+
|
57
|
+
Returns:
|
58
|
+
True if the first argument of node n is of the expected type, else False.
|
59
|
+
|
60
|
+
"""
|
61
|
+
return len(n.args) > 0 and isinstance(n.args[0], expected_type)
|
62
|
+
|
63
|
+
|
64
|
+
def _build_input_alloc_and_call_args(n: Node, input_tensors_in_node_kwargs: Dict,
|
65
|
+
inputs_as_list: bool) -> Tuple[List, List]:
|
66
|
+
"""
|
67
|
+
Build the tensor inputs list and op_call_args of the functional node.
|
68
|
+
|
69
|
+
Args:
|
70
|
+
n: fx node.
|
71
|
+
input_tensors_in_node_kwargs: A dictionary of node kwarg name and input fx node.
|
72
|
+
inputs_as_list: Whether the node's inputs are passed as a list.
|
73
|
+
|
74
|
+
Returns:
|
75
|
+
A list of updated op_call args.
|
76
|
+
A list of tensor allocations in node's inputs.
|
77
|
+
|
78
|
+
"""
|
79
|
+
|
80
|
+
tensor_input_alloc = []
|
81
|
+
op_call_args = list(n.args)
|
82
|
+
if inputs_as_list:
|
83
|
+
op_call_args.pop(0)
|
84
|
+
else:
|
85
|
+
for in_node in n.all_input_nodes:
|
86
|
+
# The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
|
87
|
+
for i, arg in enumerate(n.args):
|
88
|
+
if arg == in_node:
|
89
|
+
tensor_input_alloc.append(i)
|
90
|
+
for k, arg in input_tensors_in_node_kwargs.items():
|
91
|
+
if arg == in_node:
|
92
|
+
tensor_input_alloc.append(k)
|
93
|
+
|
94
|
+
return op_call_args, tensor_input_alloc
|
95
|
+
|
96
|
+
|
97
|
+
def _extract_torch_layer_data(node_module: torch.nn.Module,
|
98
|
+
to_numpy: Callable) -> Tuple[Any, Dict[str, np.ndarray], Dict]:
|
99
|
+
"""
|
100
|
+
Extract required data from a non-functional node to rebuild the PyTorch layer.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
node_module: Torch layer, such as nn.Conv2d, nn.Linear, etc.
|
39
104
|
to_numpy: Function to convert framework's tensor to a Numpy array.
|
40
105
|
|
41
106
|
Returns:
|
42
|
-
|
107
|
+
Node layer class.
|
108
|
+
A mapping between the layer's named parameters and buffers to their tensor values.
|
109
|
+
A framework_attr dictionary required to instantiate the node with the layer class.
|
110
|
+
"""
|
111
|
+
node_type = type(node_module)
|
112
|
+
if not isinstance(node_module, torch.nn.Module):
|
113
|
+
Logger.error(f"Expected an instance of torch.nn.Module for node {node_module.name}, but got {node_type}")
|
114
|
+
# Extract the instance framework_attr (i.e. the arguments the class instance was initialized with). "fullargspec"
|
115
|
+
# is a list of the layer's attribute names, that will be used as keys of the framework_attr dictionary. We take the
|
116
|
+
# values from the layer instance.
|
117
|
+
fullargspec = inspect.getfullargspec(node_type.__init__).args
|
118
|
+
framework_attr = {k: v for k, v in node_module.__dict__.items() if k in fullargspec}
|
119
|
+
# The "bias" argument doesn't appear in the node_module.__dict__, so we add it manually.
|
120
|
+
if hasattr(node_module, BIAS) and BIAS in fullargspec:
|
121
|
+
framework_attr[BIAS] = False if node_module.bias is None else True
|
122
|
+
|
123
|
+
# Extract layer weights and named buffers.
|
124
|
+
weights = {n: w for n, w in _extract_parameters_and_buffers(node_module, to_numpy).items() if len(w.shape) > 0}
|
125
|
+
return node_type, weights, framework_attr
|
126
|
+
|
127
|
+
|
128
|
+
def _extract_input_and_output_shapes(_node: Node) -> Tuple[List, List]:
|
129
|
+
"""
|
130
|
+
Extract input and output shapes of a node.
|
131
|
+
Args:
|
132
|
+
_node: fx node.
|
133
|
+
|
134
|
+
Returns:
|
135
|
+
Input and output shapes as lists.
|
43
136
|
"""
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
137
|
+
input_shape = []
|
138
|
+
if _node.op != PLACEHOLDER:
|
139
|
+
for i, input_node in enumerate(_node.all_input_nodes):
|
140
|
+
tensor_meta = input_node.meta
|
141
|
+
if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
|
142
|
+
input_shape += [list(tensor_meta[TENSOR_META].shape)]
|
143
|
+
elif tensor_meta[TYPE] == tuple:
|
144
|
+
input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
|
145
|
+
elif tensor_meta[TYPE] == int:
|
146
|
+
input_shape += [[1]]
|
147
|
+
|
148
|
+
if _node.meta[TYPE] == torch.Tensor:
|
149
|
+
output_shape = [list(_node.meta[TENSOR_META].shape)]
|
150
|
+
elif _node.meta[TYPE] in (list, tuple):
|
151
|
+
output_shape = [list(m.shape) for m in _node.meta[TENSOR_META]]
|
152
|
+
elif _node.meta[TYPE] == int:
|
153
|
+
output_shape = [[1]]
|
154
|
+
else:
|
155
|
+
output_shape = []
|
50
156
|
|
51
|
-
|
52
|
-
weights.update(named_buffer_weights)
|
53
|
-
return weights
|
157
|
+
return input_shape, output_shape
|
54
158
|
|
55
159
|
|
56
160
|
def nodes_builder(model: GraphModule,
|
@@ -67,135 +171,104 @@ def nodes_builder(model: GraphModule,
|
|
67
171
|
Returns:
|
68
172
|
A list of Graph nodes that were built from the fx GraphModule nodes.
|
69
173
|
"""
|
70
|
-
#
|
71
|
-
inputs = []
|
72
|
-
|
73
|
-
nodes = []
|
74
|
-
output_nodes = []
|
174
|
+
# Init function variables:
|
175
|
+
inputs, outputs = [], []
|
176
|
+
nodes, output_nodes = [], []
|
75
177
|
fx_node_2_graph_node = {}
|
76
178
|
consts_dict = {}
|
77
179
|
used_consts = set()
|
78
180
|
|
181
|
+
# Init parameters & buffers dictionary of the entire model. We later extract the constants values from this dictionary.
|
182
|
+
model_parameters_and_buffers = _extract_parameters_and_buffers(model, to_numpy)
|
183
|
+
|
79
184
|
for node in model.graph.nodes:
|
80
|
-
|
81
|
-
|
185
|
+
|
186
|
+
# ##############################################
|
187
|
+
# Extract node type and framework attributes #
|
188
|
+
# ##############################################
|
189
|
+
weights = {}
|
190
|
+
framework_attr = {}
|
82
191
|
node_has_activation = True
|
192
|
+
|
83
193
|
if node.target in module_dict.keys():
|
84
|
-
|
85
|
-
node_type =
|
86
|
-
|
87
|
-
fullargspec = inspect.getfullargspec(node_type.__init__).args
|
88
|
-
framework_attr = {k: v for k, v in framework_attr.items() if k in fullargspec}
|
89
|
-
if hasattr(node_module, BIAS) and BIAS in fullargspec:
|
90
|
-
framework_attr[BIAS] = False if node_module.bias is None else True
|
194
|
+
# PyTorch module node, such as nn.Conv2d or nn.Linear.
|
195
|
+
node_type, weights, framework_attr = _extract_torch_layer_data(module_dict[node.target], to_numpy)
|
196
|
+
|
91
197
|
elif node.op == CALL_FUNCTION:
|
198
|
+
# Node is a function that handles a parameter\buffer in the model.
|
92
199
|
node_type = node.target
|
93
|
-
if node_type
|
200
|
+
if node_type in [getattr, getitem]:
|
94
201
|
node_has_activation = False
|
95
|
-
|
96
|
-
'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
|
97
|
-
'converting the model.')
|
202
|
+
|
98
203
|
elif node.op == PLACEHOLDER:
|
204
|
+
# Input node to the model.
|
99
205
|
node_type = DummyPlaceHolder
|
206
|
+
|
100
207
|
elif node.op == OUTPUT:
|
208
|
+
# Output node of the model. Only saved in output_nodes for later handling.
|
101
209
|
output_nodes += node.all_input_nodes
|
102
210
|
continue
|
211
|
+
|
103
212
|
elif node.op == CALL_METHOD:
|
213
|
+
# Node is a PyTorch function such as torch.add, torch.reshape etc.
|
104
214
|
if hasattr(torch, node.target):
|
105
215
|
node_type = getattr(torch, node.target)
|
106
216
|
elif hasattr(torch.Tensor, node.target):
|
107
217
|
node_type = getattr(torch.Tensor, node.target)
|
108
218
|
else:
|
109
|
-
Logger.critical(f"The call method '{node.target}' is not supported.")
|
110
|
-
elif node.op == GET_ATTR:
|
111
|
-
Logger.warning(
|
112
|
-
'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
|
113
|
-
'converting the model.')
|
114
|
-
else:
|
115
|
-
Logger.critical(f'Encountered an unsupported node type in node: {node.name}.')
|
219
|
+
Logger.critical(f"The call method '{node.target}' in {node} is not supported.")
|
116
220
|
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
named_parameters_weights = {name: to_numpy(parameter) for name, parameter in
|
121
|
-
module_dict[node.target].named_parameters()}
|
122
|
-
named_buffer_weights = {name: to_numpy(parameter) for name, parameter in
|
123
|
-
module_dict[node.target].named_buffers() if len(parameter.shape) > 0}
|
124
|
-
weights.update(named_parameters_weights)
|
125
|
-
weights.update(named_buffer_weights)
|
126
|
-
|
127
|
-
if node.op == GET_ATTR:
|
128
|
-
new_const = extract_holder_weights(node, node.target, model, weights, to_numpy)
|
129
|
-
if list(new_const.keys())[0] in consts_dict:
|
221
|
+
elif node.op == GET_ATTR:
|
222
|
+
# Node holding a constant -> add it to consts_dict so we can add it later to the weights of the next node.
|
223
|
+
if node.target in consts_dict:
|
130
224
|
Logger.critical('A constant weight appears to have been recorded multiple times.')
|
131
|
-
consts_dict.
|
225
|
+
consts_dict[node] = model_parameters_and_buffers[node.target]
|
132
226
|
continue
|
227
|
+
else:
|
228
|
+
Logger.critical(f'Encountered an unsupported node type in node: {node.name}.')
|
133
229
|
|
134
|
-
#
|
135
|
-
input_shape = []
|
230
|
+
# Add constants to weights dictionary.
|
136
231
|
if node.op != PLACEHOLDER:
|
137
232
|
for i, input_node in enumerate(node.all_input_nodes):
|
138
233
|
if input_node in consts_dict:
|
139
234
|
used_consts.add(input_node)
|
140
235
|
weights.update({i: consts_dict[input_node]})
|
141
236
|
|
142
|
-
|
143
|
-
|
144
|
-
input_shape += [list(tensor_meta[TENSOR_META].shape)]
|
145
|
-
elif tensor_meta[TYPE] == tuple:
|
146
|
-
input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
|
147
|
-
elif tensor_meta[TYPE] == int:
|
148
|
-
input_shape += [[1]]
|
149
|
-
|
150
|
-
# extract output shapes
|
151
|
-
if node.meta[TYPE] == torch.Tensor:
|
152
|
-
output_shape = [list(node.meta[TENSOR_META].shape)]
|
153
|
-
elif node.meta[TYPE] in (list, tuple):
|
154
|
-
output_shape = [list(m.shape) for m in node.meta[TENSOR_META]]
|
155
|
-
elif node.meta[TYPE] == int:
|
156
|
-
output_shape = [[1]]
|
157
|
-
else:
|
158
|
-
output_shape = []
|
159
|
-
|
160
|
-
# filter Nodes from framework attributes, we replace these attributes with nx graph nodes
|
161
|
-
framework_attr_filtered = {}
|
162
|
-
framework_attr_nodes = {}
|
163
|
-
for k, v in framework_attr.items():
|
164
|
-
if isinstance(v, torch.fx.node.Node):
|
165
|
-
framework_attr_nodes[k] = v
|
166
|
-
else:
|
167
|
-
framework_attr_filtered[k] = v
|
168
|
-
framework_attr = framework_attr_filtered
|
169
|
-
|
170
|
-
# filter Nodes from node kwargs, we replace these attributes with nx graph nodes
|
171
|
-
node_kwargs = {}
|
172
|
-
for k, v in node.kwargs.items():
|
173
|
-
if not isinstance(v, torch.fx.node.Node):
|
174
|
-
node_kwargs[k] = v
|
237
|
+
# Extract input and output shapes of the node.
|
238
|
+
input_shape, output_shape = _extract_input_and_output_shapes(node)
|
175
239
|
|
176
|
-
#
|
240
|
+
# Initiate graph nodes.
|
177
241
|
if node.op in [CALL_METHOD, CALL_FUNCTION]:
|
178
242
|
graph_node_type = FunctionalNode
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
243
|
+
|
244
|
+
# Filter FX nodes from node_kwargs. These FX nodes are tensor inputs to the node that are part of the
|
245
|
+
# model's graph. We remove them because the node_kwargs should not include input tensors of the node.
|
246
|
+
# These input tensors will be inserted in the kwargs according to the tensor_input_alloc which is used
|
247
|
+
# to convert the input_tensors list in the builder to the node's args & kwargs.
|
248
|
+
node_kwargs, input_tensors_in_node_kwargs = {}, {}
|
249
|
+
for k, v in node.kwargs.items():
|
250
|
+
if isinstance(v, Node):
|
251
|
+
input_tensors_in_node_kwargs[k] = v
|
252
|
+
else:
|
253
|
+
node_kwargs[k] = v
|
254
|
+
|
255
|
+
# Check if node's first input argument is a list of input fx nodes, such as torch.cat:
|
256
|
+
is_first_input_list_of_nodes = is_instance_first_arg(node, (list, tuple)) and all(
|
257
|
+
[isinstance(n, Node) for n in node.args[0]])
|
258
|
+
is_placeholder_a_list = is_instance_first_arg(node, Node) and \
|
259
|
+
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple)
|
260
|
+
inputs_as_list = is_first_input_list_of_nodes or is_placeholder_a_list
|
261
|
+
|
262
|
+
# Build tensor_input_alloc required for the model builder. All input nodes are received as a list in the builder,
|
263
|
+
# so tensor_input_alloc is used to allocate each input tensor in the correct place in the node's args & kwargs.
|
264
|
+
op_call_args, tensor_input_alloc = _build_input_alloc_and_call_args(node, input_tensors_in_node_kwargs,
|
265
|
+
inputs_as_list)
|
266
|
+
|
267
|
+
# Remove torch.fx.node.Node from inputs to the functional node. FX nodes are input tensors in the builder,
|
268
|
+
# so they are removed from the op_call_args (same as op_call_kwargs) and are inserted back according to the
|
269
|
+
# tensor_input_alloc list.
|
197
270
|
op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
|
198
|
-
#
|
271
|
+
# Convert torch.fx.immutable_collections.immutable_list to tuple.
|
199
272
|
op_call_args = [tuple(arg) if isinstance(arg, torch.fx.immutable_collections.immutable_list) else arg
|
200
273
|
for arg in op_call_args]
|
201
274
|
|
@@ -205,8 +278,12 @@ def nodes_builder(model: GraphModule,
|
|
205
278
|
INPUTS_AS_LIST: inputs_as_list,
|
206
279
|
TENSOR_INPUT_ALLOCS: tensor_input_alloc}
|
207
280
|
else:
|
281
|
+
if not all([not isinstance(v, Node) for v in framework_attr.values()]):
|
282
|
+
Logger.critical(f'Found FX nodes in framework attributes of {node.name}. This node type should not contain any.') # pragma: no cover
|
283
|
+
|
208
284
|
graph_node_type = BaseNode
|
209
285
|
kwargs = {}
|
286
|
+
|
210
287
|
graph_node = graph_node_type(name=node.name,
|
211
288
|
framework_attr=framework_attr,
|
212
289
|
input_shape=input_shape,
|
@@ -216,7 +293,7 @@ def nodes_builder(model: GraphModule,
|
|
216
293
|
has_activation=node_has_activation,
|
217
294
|
**kwargs)
|
218
295
|
|
219
|
-
#
|
296
|
+
# Generate graph inputs list.
|
220
297
|
if node.op == PLACEHOLDER:
|
221
298
|
for ii in range(len(output_shape)):
|
222
299
|
inputs.append(graph_node)
|
@@ -224,12 +301,12 @@ def nodes_builder(model: GraphModule,
|
|
224
301
|
fx_node_2_graph_node[node] = graph_node
|
225
302
|
nodes.append(graph_node)
|
226
303
|
|
227
|
-
#
|
304
|
+
# Check whether all extracted constants were used in the graph.
|
228
305
|
not_connected_consts = [c for c in consts_dict if c not in used_consts]
|
229
306
|
if not_connected_consts:
|
230
|
-
Logger.critical(f'Error reading graph: These constants are not connected in the graph: {not_connected_consts}.')
|
307
|
+
Logger.critical(f'Error reading graph: These constants are not connected in the graph: {not_connected_consts}.') # pragma: no cover
|
231
308
|
|
232
|
-
#
|
309
|
+
# Generate graph outputs list.
|
233
310
|
for node in output_nodes:
|
234
311
|
outputs.append(OutTensor(fx_node_2_graph_node[node], output_nodes.index(node)))
|
235
312
|
|
@@ -216,7 +216,7 @@ def _set_final_resource_utilization(graph: Graph,
|
|
216
216
|
# No relevant nodes have been quantized with an effect on the given target - since we only consider
|
217
217
|
# in the model's final size the quantized layers size, this means that the final size for this target
|
218
218
|
# is zero.
|
219
|
-
Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded"
|
219
|
+
Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded "
|
220
220
|
f"final ru for this target would be 0.")
|
221
221
|
final_ru_dict[ru_target] = 0
|
222
222
|
|
@@ -148,6 +148,6 @@ class OperationsToLayers:
|
|
148
148
|
qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name)
|
149
149
|
if layer in existing_layers:
|
150
150
|
Logger.critical(f'Found layer {layer.__name__} in more than one '
|
151
|
-
|
151
|
+
f'OperatorsSet') # pragma: no cover
|
152
152
|
else:
|
153
153
|
existing_layers.update({layer: qco_by_opset_name})
|
{mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.1.0.20240622.419.dist-info → mct_nightly-2.1.0.20240624.520.dist-info}/top_level.txt
RENAMED
File without changes
|