mct-nightly 2.1.0.20240623.439__py3-none-any.whl → 2.1.0.20240625.423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.1.0.20240623.439
+ Version: 2.1.0.20240625.423
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -77,11 +77,13 @@ for hands-on learning. For example:

  Currently, MCT is being tested on various Python, Pytorch and TensorFlow versions:

- | | PyTorch 1.13 | PyTorch 2.0 | PyTorch 2.1 |
+
+ | | PyTorch 2.1 | PyTorch 2.2 | PyTorch 2.3 |
  |-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
- | Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch113.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch113.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch20.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch20.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch21.yml) |
- | Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch112.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch112.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch113.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch113.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch20.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch20.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch21.yml) |
- | Python 3.11 | | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml) |
+ | Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch21.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml) |
+ | Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch21.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml) |
+ | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) |
+


  | | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 |
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=nKuqjriGQEh_l4SctkdKW63z6vsnfkDyKoTdGEiQAbI,1573
+ model_compression_toolkit/__init__.py,sha256=8DfLm4qcvVO15TSGHhbmoV4qidIIU5TOEe0D8eMNg1M,1573
  model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -7,7 +7,7 @@ model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
  model_compression_toolkit/core/graph_prep_runner.py,sha256=kM70wmNG3yMFiGQc0uO0wn9j4ZbSWxUEykpxDK55doc,10567
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=0ga95vh_ZXO79r8FB26L5GIZKHkG98wq1hMsNH1bIeU,6453
- model_compression_toolkit/core/runner.py,sha256=kMr8pK2Z_F1Fl3uHf6ymeNKEH1NaPWQjEGEqM7sRn04,12654
+ model_compression_toolkit/core/runner.py,sha256=4TtOgyNb4cXr52dOlDqYxLm3rnLR6uHPDNoZiEFL9XA,12655
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
  model_compression_toolkit/core/common/framework_implementation.py,sha256=8b6M1GcUR9bDgoxwqyNP8C6KSU9OTQ5hIk20Y74eLPo,20896
@@ -33,7 +33,7 @@ model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaU
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
  model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
- model_compression_toolkit/core/common/graph/functional_node.py,sha256=_6HsBeLlrpLvXhLPRJswcyDa4z16-O3xzHzGuv46zBc,3897
+ model_compression_toolkit/core/common/graph/functional_node.py,sha256=BbxQ-WRk4R-5hbpQDBANkhRRTkaG7eogeiJwLfLb_EU,3950
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Pvpkirt3ziWEXDEspgOhR8ALf-XAZUh-78IkXg9YMWs,18830
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=b3RJ9XpbN2XXlCXEVjxLg3NenmtFfnp_UBRKDIEka8A,18698
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -233,9 +233,9 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchno
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=Fs2YQBD4KJV-pGLOMqm-p485bfq2JDYgCzFroRljCoM,3933
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=iX8bLHtw2osP42-peNLTRmbpX3cUxdGsAbEfw7NLpx0,3935
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=zKSgtVw_P9fUvdq4e7P9yaLDPG_vZ0cecM9sVPtm1ns,3799
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=iTuP1hjuTZTGcE7izfs_UOWBGeEBFRvRIU4QCh-b21M,4627
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=7GZY7lU3LUUaO5iiccHkUP62PB0QeGAGOZdUSGMkFBY,4450
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=XhiLVcnCc_gF-6mjxbf9C4bYg5YL_GCvDJmcdLkBNAg,4151
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py,sha256=CXSMASpc_Zed3BJ2CsER69zKxE6ncFvvKQWDO1JxKYI,5849
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=VNg-VgzCxSyqy2J3neEPl6U0SPO8UIVU_T47bGhz4FE,38459
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=q1a3HieQtaOmWG2WGXp6GHYAvxa3CZ9dJUx9dqMAsS8,5695
@@ -261,7 +261,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
  model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
  model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
  model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=LiGV-ZqlhxN1evpM-ur2dDVPowhrLwO7JZa7AGPftSk,12913
+ model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=ESL8k7RLZogTyG_oTTFDmm4RauZvx2gU-UvnOnEsH6Q,15948
  model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=TaolORuwBZEddWe-q0Mg79Nmswz-Sq3-9-4o8UxFQ50,1028
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -425,7 +425,7 @@ model_compression_toolkit/target_platform_capabilities/target_platform/targetpla
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py,sha256=fIheShGOnxWYKqT8saHpBJqOU5RG_1Hp9qHry7IviIw,2115
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py,sha256=Cl6-mACpje2jM8RJkibbqE3hvTkFR3r26-lW021mIiA,4019
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py,sha256=1JN3yvNiJyDfva0tLTH3ej_qORzrQcPz32bSMKl49_0,6720
+ model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py,sha256=iZDgHd0SVbgNTT-jtSP0SWsaRGfAJM_p-wpBlBkpRAQ,6723
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py,sha256=KP8IWlHzkXzVjqIiRtAW6sTYyHJ2wVFFX4hMt_N6o3s,9910
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities_component.py,sha256=FvrYI0Qy7DCmDp2gyUYyCZq5pY84JgLtJqSIiVTJ8Ss,1030
  model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -517,8 +517,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
- mct_nightly-2.1.0.20240623.439.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.1.0.20240623.439.dist-info/METADATA,sha256=WEbIEXcMV0ByEAXphuekS9-QKjA-l5koAM7yAhRQuxc,19726
- mct_nightly-2.1.0.20240623.439.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- mct_nightly-2.1.0.20240623.439.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.1.0.20240623.439.dist-info/RECORD,,
+ mct_nightly-2.1.0.20240625.423.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.1.0.20240625.423.dist-info/METADATA,sha256=7jBDYu3Qpt8uYnxWqxnRbqI3O7C1dfX60xyIDj6LrYA,19719
+ mct_nightly-2.1.0.20240625.423.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ mct_nightly-2.1.0.20240625.423.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.1.0.20240625.423.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.1.0.20240623.000439"
+ __version__ = "2.1.0.20240625.000423"
@@ -1,4 +1,4 @@
- from typing import Dict, Any, Tuple, Type
+ from typing import Dict, Any, Tuple, Type, List, Union

  from model_compression_toolkit.constants import FOUND_TF
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
@@ -25,7 +25,7 @@ class FunctionalNode(BaseNode):
  functional_op: Any = None,
  inputs_as_list: bool = False,
  has_activation: bool = True,
- tensor_input_allocs = None):
+ tensor_input_allocs: List[Union[int, str]] = None):
  """
  Init a FunctionalNode object.

@@ -44,8 +44,7 @@ class FunctionalNode(BaseNode):
  functional_op: The op the node implements.
  inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
  has_activation: Whether the node has activations that we might want to quantize.
- tensor_input_allocs: A list of indices for activation tensors in the node's input tensor list
-
+ tensor_input_allocs: A list of indices and strings for allocations input tensors in the node's args and kwargs.
  """

  super().__init__(name,
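Note on this change: `tensor_input_allocs` now mixes integer positions (tensors passed as positional args) with strings (tensors passed as kwargs). The standalone sketch below is illustrative only — `place_inputs` and `allocs` are hypothetical names, not MCT API — and shows how such a mixed allocation list can scatter a flat list of input tensors back into a call's args and kwargs.

```python
# Illustrative sketch only (hypothetical helper, not part of mct-nightly).
from typing import Any, Dict, List, Tuple, Union

def place_inputs(tensors: List[Any],
                 allocs: List[Union[int, str]],
                 args: List[Any],
                 kwargs: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
    """Scatter tensors back into args/kwargs according to a mixed allocation list."""
    args, kwargs = list(args), dict(kwargs)
    for tensor, alloc in zip(tensors, allocs):
        if isinstance(alloc, int):   # tensor was originally a positional argument
            args.insert(alloc, tensor)
        else:                        # tensor was originally passed as a keyword argument
            kwargs[alloc] = tensor
    return args, kwargs

# e.g. for layer_norm(x, shape, weight=w), allocs could look like [0, 'weight']
```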
@@ -106,7 +106,7 @@ def _run_operation(n: BaseNode,
  input_tensors: List,
  op_func: Any,
  quantize_node_activation_fn,
- use_activation_quantization: bool) -> Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]:
+ use_activation_quantization: bool) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  """
  Applying the layer (op_func) to the input tensors (input_tensors).
  If quantized is set to True, and the layer's corresponding node (n) has quantization
@@ -126,17 +126,17 @@ def _run_operation(n: BaseNode,
  op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
  functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}

- if not (isinstance(n, FunctionalNode) and isinstance(op_func, PytorchQuantizationWrapper)):
- # Insert positional weights only when not a quantized functional node, because quantized functional nodes
- # insert the quantized weights in the wrapper.
+ # Insert positional weights only when not a quantized functional node, because quantized functional nodes
+ # insert the quantized weights in the wrapper.
+ if isinstance(n, FunctionalNode) and isinstance(op_func, PytorchQuantizationWrapper):
+ _tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]
+ else:
  input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
  # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
  # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
  input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
  for t in input_tensors]
  _tensor_input_allocs = None
- else:
- _tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]

  if isinstance(n, FunctionalNode) and n.inputs_as_list:
  out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
@@ -152,6 +152,8 @@ def _run_operation(n: BaseNode,
  out_tensors_of_n_float = torch.cat(out_tensors_of_n_float, dim=0)
  out_tensors_of_n = quantize_node_activation_fn(out_tensors_of_n_float)

+ if not isinstance(out_tensors_of_n, list):
+ out_tensors_of_n, out_tensors_of_n_float = [out_tensors_of_n], [out_tensors_of_n_float]
  return out_tensors_of_n, out_tensors_of_n_float


@@ -318,12 +320,8 @@ class PytorchModel(torch.nn.Module):
  quantize_node_activation_fn=activation_quantization_fn,
  use_activation_quantization=use_activation_quantization)

- if isinstance(out_tensors_of_n, list):
- node_to_output_tensors_dict.update({node: out_tensors_of_n})
- node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})
- else:
- node_to_output_tensors_dict.update({node: [out_tensors_of_n]})
- node_to_output_tensors_dict_float.update({node: [out_tensors_of_n_float]})
+ node_to_output_tensors_dict.update({node: out_tensors_of_n})
+ node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})

  if self.append2output:
  outputs = _generate_outputs(self.append2output,
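With this change, `_run_operation` always returns lists, so `PytorchModel` no longer branches on single-tensor outputs. A minimal, self-contained illustration of that normalization idiom (not MCT code):

```python
# Minimal illustration of the list-normalization idiom (assumes only the torch import).
import torch

def as_list(out):
    return out if isinstance(out, list) else [out]

single = torch.zeros(1)
print(as_list(single))            # [tensor([0.])]
print(as_list([single, single]))  # already a list, returned unchanged
```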
@@ -19,6 +19,7 @@ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common.graph.base_graph import Graph
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
  from model_compression_toolkit.core.pytorch.constants import IN_CHANNELS, OUT_CHANNELS, KERNEL_SIZE, KERNEL, BIAS
  from model_compression_toolkit.core.common import FrameworkInfo

@@ -37,7 +38,7 @@ class FunctionalConvSubstitution(common.BaseSubstitution):

  def substitute(self,
  graph: Graph,
- func_node: BaseNode) -> Graph:
+ func_node: FunctionalNode) -> Graph:
  """
  Substitute functional and conv/linear layer with torch layer
  Args:
@@ -60,9 +61,15 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
  # Create new node of layer convolution
  if 1 not in func_node.weights:
  Logger.critical(f'Weight input missing for node {func_node.name}.') # pragma: no cover
- weight = func_node.weights[1]
- bias = func_node.weights.get(2)
- framework_attr = func_node.framework_attr
+ # Extract index of kernel and bias according to tensor_input_allocs if they were input as kwargs. If
+ # they were input as args, use their fixed positions.
+ weight_index = func_node.tensor_input_allocs.index(KERNEL) if KERNEL in func_node.tensor_input_allocs else 1
+ bias_index = func_node.tensor_input_allocs.index(BIAS) if BIAS in func_node.tensor_input_allocs else 2
+ if weight_index not in func_node.weights:
+ Logger.critical(f'Mismatch between tensor_input_allocs and weight index in node {func_node.name}.') # pragma: no cover
+ weight = func_node.weights[weight_index]
+ bias = func_node.weights.get(bias_index)
+ framework_attr = func_node.op_call_kwargs
  framework_attr.update({OUT_CHANNELS: weight.shape[out_channel_index]})
  framework_attr.update({IN_CHANNELS: weight.shape[in_channel_index]})
  framework_attr.update({KERNEL_SIZE: weight.shape[2:]})
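The rewritten substitution no longer assumes the kernel and bias sit at fixed positional indices; when the functional call passed them as kwargs, `tensor_input_allocs` records their names instead. A hedged illustration of the two call styles an FX trace can record (plain PyTorch, not MCT code):

```python
# Two equivalent functional conv calls that an FX trace records differently
# (weight/bias as positional args vs. as kwargs). Illustration only.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)

y_positional = F.conv2d(x, w, b)           # kernel at arg index 1, bias at index 2
y_keyword = F.conv2d(x, weight=w, bias=b)  # kernel/bias reachable only by name
assert torch.allclose(y_positional, y_keyword)
```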
@@ -20,6 +20,7 @@ import torch.nn.functional as F
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import BaseNode, Graph
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
  from model_compression_toolkit.core.pytorch.constants import *
  from model_compression_toolkit.logger import Logger

@@ -37,9 +38,12 @@ class FunctionalBatchNorm(common.BaseSubstitution):
  super().__init__(matcher_instance=bn_node)

  @staticmethod
- def get_attributes_from_weights(node: BaseNode) -> Dict:
+ def get_attributes_from_weights(node: FunctionalNode) -> Dict:
  """
- convert functional batch_norm positional weights to BatchNorm2d weights
+ Convert functional batch_norm positional weights to BatchNorm2d weights. Extract indices of gamma
+ and beta according to tensor_input_allocs if they were input as kwargs. If they were input as args,
+ use their fixed positions.
+
  Args:
  node: functional batch_norm node.

@@ -53,23 +57,22 @@ class FunctionalBatchNorm(common.BaseSubstitution):
  GAMMA: np.ones(node.weights[1].shape),
  BETA: np.zeros(node.weights[1].shape)}

- has_weight = WEIGHT not in node.framework_attr
- has_bias = BIAS not in node.framework_attr
+ # Check if weight and/or bias were not given.
+ if KERNEL in node.tensor_input_allocs:
+ weights_dict[GAMMA] = node.weights[node.tensor_input_allocs.index(KERNEL)]
+ elif KERNEL not in node.op_call_kwargs:
+ weights_dict[GAMMA] = node.weights[3]

- if 3 in node.weights:
- if has_weight:
- weights_dict[GAMMA] = node.weights[3]
- else:
- weights_dict[BETA] = node.weights[3]
- if 4 in node.weights:
- assert has_bias
+ if BIAS in node.tensor_input_allocs:
+ weights_dict[BETA] = node.weights[node.tensor_input_allocs.index(BIAS)]
+ elif BIAS not in node.op_call_kwargs:
  weights_dict[BETA] = node.weights[4]

  return weights_dict

  def substitute(self,
  graph: Graph,
- node: BaseNode) -> Graph:
+ node: FunctionalNode) -> Graph:
  """
  Substitute functional.batch_norm and its inputs with BatchNorm2d.
  Args:
@@ -87,10 +90,13 @@ class FunctionalBatchNorm(common.BaseSubstitution):
  bn_node_weights = self.get_attributes_from_weights(node)
  if not bn_node_weights:
  return graph
+ framework_attr = {NUM_FEATURES: out_channels}
+ if EPSILON in node.op_call_kwargs:
+ framework_attr.update({EPSILON: node.op_call_kwargs[EPSILON]})
+ if MOMENTUM in node.op_call_kwargs:
+ framework_attr.update({MOMENTUM: node.op_call_kwargs[MOMENTUM]})
  new_batchnorm2d = BaseNode(name=node.name + '_into_BatchNorm2d',
- framework_attr={NUM_FEATURES: out_channels,
- EPSILON: EPSILON_VAL,
- MOMENTUM: MOMENTUM_VAL},
+ framework_attr=framework_attr,
  input_shape=node.output_shape,
  output_shape=node.output_shape,
  weights=bn_node_weights,
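Similarly, the new `framework_attr` for the replacement BatchNorm2d node copies `eps` and `momentum` only when they were explicitly present in the functional call's kwargs. A small standalone example of such a call (plain PyTorch, not MCT code):

```python
# A functional batch_norm call whose eps/momentum arrive as kwargs (illustration only).
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)
running_mean, running_var = torch.zeros(4), torch.ones(4)
gamma, beta = torch.ones(4), torch.zeros(4)

y = F.batch_norm(x, running_mean, running_var, weight=gamma, bias=beta,
                 training=False, momentum=0.1, eps=1e-5)
print(y.shape)  # torch.Size([2, 4, 8, 8])
```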
@@ -21,6 +21,7 @@ from typing import Dict, Tuple, List
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import BaseNode, Graph
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
  from model_compression_toolkit.core.pytorch.constants import *
  from model_compression_toolkit.logger import Logger

@@ -38,9 +39,11 @@ class FunctionalLayerNorm(common.BaseSubstitution):
  super().__init__(matcher_instance=ln_node)

  @staticmethod
- def get_attributes_from_weights(node: BaseNode, normalized_shape: [Tuple, List, int]) -> Dict:
+ def get_attributes_from_weights(node: FunctionalNode, normalized_shape: [Tuple, List, int]) -> Dict:
  """
- Parse layer_norm(input, normalized_shape, weight=None, bias=None)
+ Convert functional layer_norm positional weights to LayerNorm weights. Extract indices of gamma
+ and beta according to tensor_input_allocs if they were input as kwargs. If they were input as args,
+ use their fixed positions.
  Args:
  node: Node that match the pattern in the substitution init.
  normalized_shape: nn.LayerNorm "normalized_shape" argument
@@ -50,28 +53,26 @@ class FunctionalLayerNorm(common.BaseSubstitution):

  """
  # Define default weight and bias
- weights_dict = {GAMMA: np.ones(normalized_shape), # Default value in case weight is not given
- BETA: np.zeros(normalized_shape) # Default value in case bias is not given
+ weights_dict = {GAMMA: np.ones(normalized_shape),  # Default value in case weight is not given
+ BETA: np.zeros(normalized_shape)  # Default value in case bias is not given
  }

  # Check if weight and/or bias were not given.
- has_weight = WEIGHT not in node.framework_attr
- has_bias = BIAS not in node.framework_attr
+ if KERNEL in node.tensor_input_allocs:
+ weights_dict[GAMMA] = node.weights[node.tensor_input_allocs.index(KERNEL)]
+ elif KERNEL not in node.op_call_kwargs:
+ weights_dict[GAMMA] = node.weights[1]

- if 1 in node.weights:
- if has_weight:
- weights_dict[GAMMA] = node.weights[1]
- else:
- weights_dict[BETA] = node.weights[1]
- if 2 in node.weights:
- assert has_bias
+ if BIAS in node.tensor_input_allocs:
+ weights_dict[BETA] = node.weights[node.tensor_input_allocs.index(BIAS)]
+ elif BIAS not in node.op_call_kwargs:
  weights_dict[BETA] = node.weights[2]

  return weights_dict

  def substitute(self,
  graph: Graph,
- node: BaseNode) -> Graph:
+ node: FunctionalNode) -> Graph:
  """
  Substitute functional.layer_norm and its inputs with LayerNorm.
  Args:
@@ -85,10 +86,11 @@ class FunctionalLayerNorm(common.BaseSubstitution):

  ln_node_weights = self.get_attributes_from_weights(node, normalized_shape)

+ framework_attr = {NORMALIZED_SHAPE: normalized_shape}
+ if EPSILON in node.op_call_kwargs:
+ framework_attr.update({EPSILON: node.op_call_kwargs[EPSILON]})
  new_layernorm = BaseNode(name=node.name + '_into_LayerNorm',
- framework_attr={NORMALIZED_SHAPE: normalized_shape,
- EPSILON: node.framework_attr.get('eps'),
- },
+ framework_attr=framework_attr,
  input_shape=node.output_shape,
  output_shape=node.output_shape,
  weights=ln_node_weights,
@@ -13,11 +13,13 @@
  # limitations under the License.
  # ==============================================================================
  import inspect
- from typing import Dict, List, Tuple, Callable
+ from operator import getitem
+ from typing import Dict, List, Tuple, Callable, Union, Any, Type
+
+ import numpy as np
  import torch
  from torch.fx import GraphModule, Node

- from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import BaseNode
  from model_compression_toolkit.core.common.graph.base_graph import OutTensor
  from model_compression_toolkit.core.common.graph.edge import Edge
@@ -28,29 +30,131 @@ from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlac
  from model_compression_toolkit.logger import Logger


- def extract_holder_weights(constant_name, node_target, model, weights, to_numpy):
+ def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule],
+ to_numpy: Callable) -> Dict[str, np.ndarray]:
  """
- Extract layer weights and named buffers to a dictionary.
+ Extract parameters & buffers from input module to a dictionary.
  Args:
- constant_name: name to write the parameters under, should be the node name.
- node_target: relevant parameter name from Pytorch FX model.
- model: Pytorch FX model.
- weights: dictionary containing the weights of the node.
+ module: FX ot PyTorch module to extract parameters and buffers from.
+
+ Returns:
+ Dictionary containing module parameters and buffers by name.
+ """
+
+ named_parameters = {name: to_numpy(parameter) for name, parameter in module.named_parameters()}
+ named_buffers = {name: to_numpy(buffer) for name, buffer in module.named_buffers()}
+
+ return {**named_parameters, **named_buffers}
+
+
+ def is_instance_first_arg(n: Node, expected_type: Union[Type, Tuple[Type]]) -> bool:
+ """
+ Check whether first argument of the node is the expected type
+ Args:
+ n: fx node.
+ expected_type: Expected 1st argument type.
+
+ Returns:
+ True is the first argument of node n is of the expected type, else return False.
+
+ """
+ return len(n.args) > 0 and isinstance(n.args[0], expected_type)
+
+
+ def _build_input_alloc_and_call_args(n: Node, input_tensors_in_node_kwargs: Dict,
+ inputs_as_list: bool) -> Tuple[List, List]:
+ """
+ Build the tensor inputs list and op_call_args of the functional node.
+
+ Args:
+ n: fx node.
+ input_tensors_in_node_kwargs: A dictionary of node kwarg name and input fx node.
+ inputs_as_list: Is node's inputs are a list.
+
+ Returns:
+ A list of updated op_call args.
+ A list of tensor allocations in node's inputs.
+
+ """
+
+ tensor_input_alloc = []
+ op_call_args = list(n.args)
+ if inputs_as_list:
+ op_call_args.pop(0)
+ else:
+ for in_node in n.all_input_nodes:
+ # The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
+ for i, arg in enumerate(n.args):
+ if arg == in_node:
+ tensor_input_alloc.append(i)
+ for k, arg in input_tensors_in_node_kwargs.items():
+ if arg == in_node:
+ tensor_input_alloc.append(k)
+
+ return op_call_args, tensor_input_alloc
+
+
+ def _extract_torch_layer_data(node_module: torch.nn.Module,
+ to_numpy: Callable) -> Tuple[Any, Dict[str, np.ndarray], Dict]:
+ """
+ Extract required data from a non-functional node to rebuild the PyTorch layer.
+
+ Args:
+ node_module: Torch layer, such as nn.Conv2d, nn.Linear, etc.
  to_numpy: Function to convert framework's tensor to a Numpy array.

  Returns:
- Updated weights dictionary.
+ Node layer class.
+ A mapping between the layer's named parameters and buffers to their tensor values.
+ A framework_attr dictionary required to instantiate the node with the layer class.
+ """
+ node_type = type(node_module)
+ if not isinstance(node_module, torch.nn.Module):
+ Logger.error(f"Expected an instance of torch.nn.Module for node {node_module.name}, but got {node_type}")
+ # Extract the instance framework_attr (i.e. the arguments the class instance was initialized with). "fullargspec"
+ # is a list of the layer's attribute names, that will be used as keys of the framework_attr dictionary. We the
+ # values from the layer instance.
+ fullargspec = inspect.getfullargspec(node_type.__init__).args
+ framework_attr = {k: v for k, v in node_module.__dict__.items() if k in fullargspec}
+ # The "bias" argument doesn't appear in the node_module.__dict__, so we add it manually.
+ if hasattr(node_module, BIAS) and BIAS in fullargspec:
+ framework_attr[BIAS] = False if node_module.bias is None else True
+
+ # Extract layer weights and named buffers.
+ weights = {n: w for n, w in _extract_parameters_and_buffers(node_module, to_numpy).items() if len(w.shape) > 0}
+ return node_type, weights, framework_attr
+
+
+ def _extract_input_and_output_shapes(_node: Node) -> Tuple[List, List]:
+ """
+ Extract input and output shapes of a node.
+ Args:
+ _node: fx node.
+
+ Returns:
+ Input and output shapes as lists.
  """
- named_parameters_weights = {constant_name: to_numpy(parameter) for name, parameter in
- model.named_parameters() if node_target == name}
- named_buffer_weights = {constant_name: to_numpy(parameter) for name, parameter in
- model.named_buffers() if node_target == name}
- if len(named_parameters_weights) + len(named_buffer_weights) > 1:
- Logger.critical("A single constant parameter must correspond to exactly one tensor. Found {len(named_parameters_weights) + len(named_buffer_weights)} parameters.")
+ input_shape = []
+ if _node.op != PLACEHOLDER:
+ for i, input_node in enumerate(_node.all_input_nodes):
+ tensor_meta = input_node.meta
+ if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
+ input_shape += [list(tensor_meta[TENSOR_META].shape)]
+ elif tensor_meta[TYPE] == tuple:
+ input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
+ elif tensor_meta[TYPE] == int:
+ input_shape += [[1]]
+
+ if _node.meta[TYPE] == torch.Tensor:
+ output_shape = [list(_node.meta[TENSOR_META].shape)]
+ elif _node.meta[TYPE] in (list, tuple):
+ output_shape = [list(m.shape) for m in _node.meta[TENSOR_META]]
+ elif _node.meta[TYPE] == int:
+ output_shape = [[1]]
+ else:
+ output_shape = []

- weights.update(named_parameters_weights)
- weights.update(named_buffer_weights)
- return weights
+ return input_shape, output_shape


  def nodes_builder(model: GraphModule,
@@ -67,135 +171,104 @@ def nodes_builder(model: GraphModule,
  Returns:
  A list of Graph nodes that were built from the fx GraphModule nodes.
  """
- # init function variables:
- inputs = []
- outputs = []
- nodes = []
- output_nodes = []
+ # Init function variables:
+ inputs, outputs = [], []
+ nodes, output_nodes = [], []
  fx_node_2_graph_node = {}
  consts_dict = {}
  used_consts = set()

+ # Init parameters & buffers dictionary of the entire model. We later extract the constants values from this dictionary.
+ model_parameters_and_buffers = _extract_parameters_and_buffers(model, to_numpy)
+
  for node in model.graph.nodes:
- # extract node type and framework attributes
- framework_attr = dict(node.kwargs)
+
+ # ##############################################
+ # Extract node type and framework attributes #
+ # ##############################################
+ weights = {}
+ framework_attr = {}
  node_has_activation = True
+
  if node.target in module_dict.keys():
- node_module = module_dict[node.target]
- node_type = type(node_module)
- framework_attr = node_module.__dict__
- fullargspec = inspect.getfullargspec(node_type.__init__).args
- framework_attr = {k: v for k, v in framework_attr.items() if k in fullargspec}
- if hasattr(node_module, BIAS) and BIAS in fullargspec:
- framework_attr[BIAS] = False if node_module.bias is None else True
+ # PyTorch module node, such as nn.Conv2d or nn.Linear.
+ node_type, weights, framework_attr = _extract_torch_layer_data(module_dict[node.target], to_numpy)
+
  elif node.op == CALL_FUNCTION:
+ # Node is a function that handle a parameter\buffer in the model.
  node_type = node.target
- if node_type == getattr:
+ if node_type in [getattr, getitem]:
  node_has_activation = False
- Logger.warning(
- 'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
- 'converting the model.')
+
  elif node.op == PLACEHOLDER:
+ # Input node to the model.
  node_type = DummyPlaceHolder
+
  elif node.op == OUTPUT:
+ # Output node of the model. Only saved in output_nodes for later handling.
  output_nodes += node.all_input_nodes
  continue
+
  elif node.op == CALL_METHOD:
+ # Node is a PyTorch function such as torch.add, torch.reshape etc.
  if hasattr(torch, node.target):
  node_type = getattr(torch, node.target)
  elif hasattr(torch.Tensor, node.target):
  node_type = getattr(torch.Tensor, node.target)
  else:
- Logger.critical(f"The call method '{node.target}' is not supported.")
- elif node.op == GET_ATTR:
- Logger.warning(
- 'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
- 'converting the model.')
- else:
- Logger.critical(f'Encountered an unsupported node type in node: {node.name}.')
+ Logger.critical(f"The call method '{node.target}' in {node} is not supported.")

- # extract layer weights and named buffers
- weights = {}
- if node.target in module_dict.keys():
- named_parameters_weights = {name: to_numpy(parameter) for name, parameter in
- module_dict[node.target].named_parameters()}
- named_buffer_weights = {name: to_numpy(parameter) for name, parameter in
- module_dict[node.target].named_buffers() if len(parameter.shape) > 0}
- weights.update(named_parameters_weights)
- weights.update(named_buffer_weights)
-
- if node.op == GET_ATTR:
- new_const = extract_holder_weights(node, node.target, model, weights, to_numpy)
- if list(new_const.keys())[0] in consts_dict:
+ elif node.op == GET_ATTR:
+ # Node holding a constant -> add to consts_dict so can add them later to weights of next node.
+ if node.target in consts_dict:
  Logger.critical('A constant weight appears to have been recorded multiple times.')
- consts_dict.update(new_const)
+ consts_dict[node] = model_parameters_and_buffers[node.target]
  continue
+ else:
+ Logger.critical(f'Encountered an unsupported node type in node: {node.name}.')

- # extract input shapes and const weights
- input_shape = []
+ # Add constants to weights dictionary.
  if node.op != PLACEHOLDER:
  for i, input_node in enumerate(node.all_input_nodes):
  if input_node in consts_dict:
  used_consts.add(input_node)
  weights.update({i: consts_dict[input_node]})

- tensor_meta = input_node.meta
- if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
- input_shape += [list(tensor_meta[TENSOR_META].shape)]
- elif tensor_meta[TYPE] == tuple:
- input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
- elif tensor_meta[TYPE] == int:
- input_shape += [[1]]
-
- # extract output shapes
- if node.meta[TYPE] == torch.Tensor:
- output_shape = [list(node.meta[TENSOR_META].shape)]
- elif node.meta[TYPE] in (list, tuple):
- output_shape = [list(m.shape) for m in node.meta[TENSOR_META]]
- elif node.meta[TYPE] == int:
- output_shape = [[1]]
- else:
- output_shape = []
-
- # filter Nodes from framework attributes, we replace these attributes with nx graph nodes
- framework_attr_filtered = {}
- framework_attr_nodes = {}
- for k, v in framework_attr.items():
- if isinstance(v, torch.fx.node.Node):
- framework_attr_nodes[k] = v
- else:
- framework_attr_filtered[k] = v
- framework_attr = framework_attr_filtered
-
- # filter Nodes from node kwargs, we replace these attributes with nx graph nodes
- node_kwargs = {}
- for k, v in node.kwargs.items():
- if not isinstance(v, torch.fx.node.Node):
- node_kwargs[k] = v
+ # Extract input and output shapes of the node.
+ input_shape, output_shape = _extract_input_and_output_shapes(node)

- # initiate graph nodes
+ # Initiate graph nodes.
  if node.op in [CALL_METHOD, CALL_FUNCTION]:
  graph_node_type = FunctionalNode
- inputs_as_list1 = len(node.args) > 0 and isinstance(node.args[0], (list, tuple)) and all(
- [isinstance(n, torch.fx.node.Node) for n in node.args[0]])
- inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
- node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
- tensor_input_alloc = []
- op_call_args = list(node.args)
- if inputs_as_list:
- op_call_args.pop(0)
- else:
- for in_node in node.all_input_nodes:
- for i, arg in enumerate(node.args):
- if arg == in_node:
- tensor_input_alloc.append(i)
- for k, arg in framework_attr_nodes.items():
- if arg == in_node:
- tensor_input_alloc.append(k)
-
- # remove torch.fx.node.Node from inputs to graph_node_type
+
+ # Filter FX nodes from node_kwargs. These FX nodes are tensor inputs to the node that are part of the
+ # model's graph. We remove them because the node_kwargs should not include input tensors of the node.
+ # These input tensors will be inserted in the kwargs according to the tensor_input_alloc which is used
+ # to convert the input_tensors list in the builder to the node's args & kwargs.
+ node_kwargs, input_tensors_in_node_kwargs = {}, {}
+ for k, v in node.kwargs.items():
+ if isinstance(v, Node):
+ input_tensors_in_node_kwargs[k] = v
+ else:
+ node_kwargs[k] = v
+
+ # Check if node's first input argument is a list of input fx nodes, such as torch.cat:
+ is_first_input_list_of_nodes = is_instance_first_arg(node, (list, tuple)) and all(
+ [isinstance(n, Node) for n in node.args[0]])
+ is_placeholder_a_list = is_instance_first_arg(node, Node) and \
+ node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple)
+ inputs_as_list = is_first_input_list_of_nodes or is_placeholder_a_list
+
+ # Build tensor_input_alloc required for the model builder. All input nodes are received as a list in the builder,
+ # so tensor_input_alloc is used to allocate each input tensor in the correct place in the node's args & kwargs.
+ op_call_args, tensor_input_alloc = _build_input_alloc_and_call_args(node, input_tensors_in_node_kwargs,
+ inputs_as_list)
+
+ # Remove torch.fx.node.Node from inputs to the functional node. FX nodes are input tensors in the builder,
+ # so they are remove from the op_call_args (same as op_call_kwargs) and are inserted back according to the
+ # tensor_input_alloc list.
  op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
- # convert torch.fx.immutable_collections.immutable_list to tuple
+ # Convert torch.fx.immutable_collections.immutable_list to tuple.
  op_call_args = [tuple(arg) if isinstance(arg, torch.fx.immutable_collections.immutable_list) else arg
  for arg in op_call_args]

@@ -205,8 +278,12 @@ def nodes_builder(model: GraphModule,
  INPUTS_AS_LIST: inputs_as_list,
  TENSOR_INPUT_ALLOCS: tensor_input_alloc}
  else:
+ if not all([not isinstance(v, Node) for v in framework_attr.values()]):
+ Logger.critical(f'Found FX nodes in framework attributes of {node.name}. This node type should not contain any.') # pragma: no cover
+
  graph_node_type = BaseNode
  kwargs = {}
+
  graph_node = graph_node_type(name=node.name,
  framework_attr=framework_attr,
  input_shape=input_shape,
@@ -216,7 +293,7 @@ def nodes_builder(model: GraphModule,
  has_activation=node_has_activation,
  **kwargs)

- # generate graph inputs list
+ # Generate graph inputs list.
  if node.op == PLACEHOLDER:
  for ii in range(len(output_shape)):
  inputs.append(graph_node)
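For context, `nodes_builder` consumes a `torch.fx.GraphModule`; the sketch below (plain torch.fx, not MCT code) shows the `op`/`target`/`args`/`kwargs` fields it maps into `BaseNode`/`FunctionalNode` objects. The shape metadata MCT relies on (populated by FX shape propagation) is omitted here.

```python
# Inspecting the torch.fx node fields that nodes_builder consumes (illustration only).
import torch
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):
    def forward(self, x, w):
        return torch.nn.functional.layer_norm(x, (8,), weight=w)

gm = symbolic_trace(Tiny())
for node in gm.graph.nodes:
    # node.op is one of: placeholder, call_function, call_method, call_module, get_attr, output
    print(node.op, node.target, node.args, node.kwargs)
```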
@@ -224,12 +301,12 @@ def nodes_builder(model: GraphModule,
  fx_node_2_graph_node[node] = graph_node
  nodes.append(graph_node)

- # make sure all extracted constants were used in the graph
+ # Check whether all extracted constants were used in the graph.
  not_connected_consts = [c for c in consts_dict if c not in used_consts]
  if not_connected_consts:
- Logger.critical(f'Error reading graph: These constants are not connected in the graph: {not_connected_consts}.')
+ Logger.critical(f'Error reading graph: These constants are not connected in the graph: {not_connected_consts}.') # pragma: no cover

- # generate graph outputs list
+ # Generate graph outputs list.
  for node in output_nodes:
  outputs.append(OutTensor(fx_node_2_graph_node[node], output_nodes.index(node)))

@@ -216,7 +216,7 @@ def _set_final_resource_utilization(graph: Graph,
  # No relevant nodes have been quantized with affect on the given target - since we only consider
  # in the model's final size the quantized layers size, this means that the final size for this target
  # is zero.
- Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded"
+ Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded "
  f"final ru for this target would be 0.")
  final_ru_dict[ru_target] = 0

@@ -148,6 +148,6 @@ class OperationsToLayers:
  qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name)
  if layer in existing_layers:
  Logger.critical(f'Found layer {layer.__name__} in more than one '
- f'OperatorsSet') # pragma: no cover
+ f'OperatorsSet')  # pragma: no cover
  else:
  existing_layers.update({layer: qco_by_opset_name})