mct-nightly 2.1.0.20240616.434__py3-none-any.whl → 2.1.0.20240617.451__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.1.0.20240616.434
+ Version: 2.1.0.20240617.451
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=mregiWggl4blOSBp61pEozF4QXuJnyAsxzpvrXIwh2k,1573
+ model_compression_toolkit/__init__.py,sha256=1OlgfgkWMboHXhaDyyG9E0dc_vedCsy6r_gAtSq-lfY,1573
  model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -33,7 +33,7 @@ model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaU
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
  model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
- model_compression_toolkit/core/common/graph/functional_node.py,sha256=71_4TrCdqR_r0mtgxmAyqI05iP5YoQQGeSmDgynuzTw,3902
+ model_compression_toolkit/core/common/graph/functional_node.py,sha256=_6HsBeLlrpLvXhLPRJswcyDa4z16-O3xzHzGuv46zBc,3897
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
@@ -186,7 +186,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_a
  model_compression_toolkit/core/keras/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
  model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py,sha256=4eJKq_Fx4mm_VuBDeeti0fTcUk1lL2yjebxCugJhvrA,8871
  model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py,sha256=hRfAjgZakDaIMuERmTVjJSa_Ww6FmEudYPO9R7SuYuQ,3914
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py,sha256=KBjGr9FzyZIPD4MFtsV3LDBdJtLa0VFdIXyx_KAnjTQ,12215
+ model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py,sha256=Xogd90kPZPvKbplZQv5B77Dq_m4aW5-bL6Jxh33VZWs,12213
  model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
  model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
  model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=Ziydik2j-LvNBXP3TSfUD6rEezPAikzQGib0_IXkmGM,6729
@@ -210,7 +210,7 @@ model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9H
  model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
+ model_compression_toolkit/core/pytorch/constants.py,sha256=AguUnAsNlj41gwuKIP_7nos3FcJHsIAjewLXSQdrDQM,2624
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=7CFt1Y3fiDaKkEVvlDd76ZmucCuVp6OZNQwwqJezKbU,27547
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Zw4gi-wjJNV8-qGv79YBWVAHmy27f7iW0c2JGNWAKD0,18199
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Pvpkirt3ziWEXDEspgOhR8ALf-XAZUh-78IkXg9YMWs,18830
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -240,7 +240,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=VNg-VgzCxSyqy2J3neEPl6U0SPO8UIVU_T47bGhz4FE,38459
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=q1a3HieQtaOmWG2WGXp6GHYAvxa3CZ9dJUx9dqMAsS8,5695
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py,sha256=joHjwiUxccypMHkTy46rI91VyapLn9yJ2YRo5ISnOH4,1987
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=jOqlelGhADEZiYUEyYj9oJZ5YLXx8jWNUlVTG6Td79Y,4919
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=lOPl5zDU3FoR9WmlxO04Pfi65MimK0gmnuHzQJodQdY,10668
@@ -261,7 +261,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
  model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
  model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
  model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=x5n8KHBqvutqS5l5AillA_FQfhf-2ibP813ixK3Gvy8,12627
+ model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=LiGV-ZqlhxN1evpM-ur2dDVPowhrLwO7JZa7AGPftSk,12913
  model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=TaolORuwBZEddWe-q0Mg79Nmswz-Sq3-9-4o8UxFQ50,1028
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -491,8 +491,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
- mct_nightly-2.1.0.20240616.434.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.1.0.20240616.434.dist-info/METADATA,sha256=qRPmiufe7bR0fBvN5cB16-cYIewTIbYTbbpfBQvjA_8,19721
- mct_nightly-2.1.0.20240616.434.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- mct_nightly-2.1.0.20240616.434.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.1.0.20240616.434.dist-info/RECORD,,
+ mct_nightly-2.1.0.20240617.451.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.1.0.20240617.451.dist-info/METADATA,sha256=-ZUI2y7SZOGKyLl6qpBE9onj-lZwfGE0wLLJI5WeqIE,19721
+ mct_nightly-2.1.0.20240617.451.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ mct_nightly-2.1.0.20240617.451.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.1.0.20240617.451.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.1.0.20240616.000434"
+ __version__ = "2.1.0.20240617.000451"
@@ -25,7 +25,7 @@ class FunctionalNode(BaseNode):
  functional_op: Any = None,
  inputs_as_list: bool = False,
  has_activation: bool = True,
- tensor_input_indices = None):
+ tensor_input_allocs = None):
  """
  Init a FunctionalNode object.

@@ -44,7 +44,7 @@ class FunctionalNode(BaseNode):
  functional_op: The op the node implements.
  inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
  has_activation: Whether the node has activations that we might want to quantize.
- tensor_input_indices: A list of indices for activation tensors in the node's input tensor list
+ tensor_input_allocs: A list of indices for activation tensors in the node's input tensor list

  """

@@ -63,7 +63,7 @@ class FunctionalNode(BaseNode):
  self.op_call_args = op_call_args
  self.functional_op = functional_op
  self.inputs_as_list = inputs_as_list
- self.tensor_input_indices = [] if tensor_input_indices is None else tensor_input_indices
+ self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

  @property
  def type(self):
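
Note: the rename from tensor_input_indices to tensor_input_allocs reflects that entries are no longer only positional indices; as the _merge_inputs change below shows, an entry may also be a keyword name. A hypothetical illustration (values assumed, not taken from the package):

    # ints allocate a tensor to a position in op_call_args,
    # strings allocate it to a key in op_call_kwargs
    node.tensor_input_allocs = [0, 2, "other"]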
@@ -74,7 +74,7 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
  model, _ = FloatKerasModelBuilder(graph=self.graph).build_model()

  # Initiate a gradient tape for automatic differentiation
- with (tf.GradientTape(persistent=True) as tape):
+ with tf.GradientTape(persistent=True) as tape:
  # Perform a forward pass (inference) to get the output, while watching
  # the input tensor for gradient computation
  tape.watch(self.input_images)
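
Note: the dropped parentheses are not purely cosmetic. A parenthesized with-item containing `as` is rejected by older Python parsers (parenthesized context managers only became official grammar in Python 3.10), so the plain form is the portable one. A self-contained sketch of the persistent-tape pattern used here (standard TensorFlow API, toy values):

    import tensorflow as tf

    x = tf.constant([1.0, 2.0])
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x)                 # watch a non-variable tensor
        y = tf.reduce_sum(x * x)
        z = tf.reduce_prod(x)
    # persistent=True permits more than one gradient() call per tape
    dy_dx = tape.gradient(y, x)       # [2., 4.]
    dz_dx = tape.gradient(z, x)       # [2., 1.]
    del tape                          # release resources held by a persistent tape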
@@ -22,6 +22,7 @@ from networkx import topological_sort

  from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.core import common
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.common import BaseNode, Graph
  from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
@@ -66,8 +67,8 @@ def _build_input_tensors_list(node: BaseNode,
  return input_tensors


- def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
- tensor_input_indices: List = None) -> List:
+ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
+ tensor_input_allocs: List = None) -> Tuple[List, Dict]:
  """
  Merge input tensors list with positional weights and op_call_args, according to correct order.

@@ -75,22 +76,30 @@ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
  _node: The node the inputs are for.
  input_tensors: activation input tensors to node.
  op_call_args: framework node call args.
+ op_call_kwargs: framework node call kwargs.
+ tensor_input_allocs: List of input allocations to node.

  Returns:
  Combined list of input_tensors and op_call_args.
  """
- if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
+ if isinstance(_node, FunctionalNode) and _node.tensor_input_allocs:
  _input_list = op_call_args.copy()
- if tensor_input_indices is None:
- tensor_input_indices = _node.tensor_input_indices
- assert len(tensor_input_indices) == len(input_tensors), \
- f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}'
- for i, t in zip(tensor_input_indices, input_tensors):
- _input_list.insert(i, t)
+ if tensor_input_allocs is None:
+ tensor_input_allocs = _node.tensor_input_allocs
+ if len(tensor_input_allocs) != len(input_tensors):
+ Logger.error(f'Mismatch between input tensors ({len(tensor_input_allocs)}) '
+ f'and indices {len(input_tensors)} in node {_node.name}.') # pragma: no cover
+ for i, t in zip(tensor_input_allocs, input_tensors):
+ # insert input tensors in either args or kwargs, according to tensor_input_allocs
+ if isinstance(i, str):
+ assert i not in op_call_kwargs
+ op_call_kwargs.update({i: t})
+ else:
+ _input_list.insert(i, t)
  else:
  _input_list = input_tensors + op_call_args

- return _input_list
+ return _input_list, op_call_kwargs


  def _run_operation(n: BaseNode,
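
Note: the reworked _merge_inputs can be summarized by this standalone re-implementation (illustration only; strings stand in for tensors):

    from typing import Dict, List, Tuple

    def merge_sketch(input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
                     tensor_input_allocs: List) -> Tuple[List, Dict]:
        # int allocs insert positionally; str allocs land in kwargs
        args = op_call_args.copy()
        for alloc, t in zip(tensor_input_allocs, input_tensors):
            if isinstance(alloc, str):
                op_call_kwargs[alloc] = t
            else:
                args.insert(alloc, t)
        return args, op_call_kwargs

    # allocs [0, "other"]: t0 goes to position 0, t1 to kwarg "other"
    args, kwargs = merge_sketch(["t0", "t1"], [2], {}, [0, "other"])
    assert args == ["t0", 2] and kwargs == {"other": "t1"}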
@@ -125,14 +134,15 @@ def _run_operation(n: BaseNode,
  # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
  input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
  for t in input_tensors]
- _tensor_input_indices = None
+ _tensor_input_allocs = None
  else:
- _tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights]
+ _tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]

  if isinstance(n, FunctionalNode) and n.inputs_as_list:
  out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
  else:
- merged_inputs = _merge_inputs(n, input_tensors, op_call_args, tensor_input_indices=_tensor_input_indices)
+ merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
+ tensor_input_allocs=_tensor_input_allocs)
  out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)

  # Add a fake quant node if the node has an activation threshold.
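
Note: the functional_kwargs.copy() is load-bearing: _merge_inputs now writes tensors into the kwargs dict it is handed, so passing the node's stored kwargs directly would let one forward pass leak tensors into the next. A toy demonstration of the hazard (illustrative names):

    stored_kwargs = {"dim": 0}

    def fill(kwargs):
        kwargs["other"] = "tensor"   # what _merge_inputs does for str allocs
        return kwargs

    fill(stored_kwargs.copy())       # stored_kwargs stays {"dim": 0}
    fill(stored_kwargs)              # stored_kwargs is now polluted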
@@ -40,7 +40,7 @@ FUNCTIONAL_OP = 'functional_op'
  OP_CALL_ARGS = 'op_call_args'
  OP_CALL_KWARGS = 'op_call_kwargs'
  INPUTS_AS_LIST = 'inputs_as_list'
- TENSOR_INPUT_INDICES = 'tensor_input_indices'
+ TENSOR_INPUT_ALLOCS = 'tensor_input_allocs'
  INPLACE = 'inplace'
  HARDTANH_MIN_VAL = 'min_val'
  HARDTANH_MAX_VAL = 'max_val'
@@ -65,10 +65,10 @@ class ReshapeWithStaticShapes(common.BaseSubstitution):

  # When a "reshape" is called with multiple arguments (e.g. x.reshape(-1, channels, height, width)
  # this substitution converts it x.reshape((-1, channels, height, width)), so need to update the
- # tensor_input_indices attribute.
- # scalar argument's shape is [1] so remove those indices from tensor_input_indices
+ # tensor_input_allocs attribute.
+ # scalar argument's shape is [1] so remove those indices from tensor_input_allocs
  # node.input_shape example: [[1, 32, 4, 32], [1], [1], [1]]
- node.tensor_input_indices = node.tensor_input_indices[:sum([i != [1] for i in node.input_shape])]
+ node.tensor_input_allocs = node.tensor_input_allocs[:sum([i != [1] for i in node.input_shape])]

  # modify the node input info
  node.input_shape = [node.input_shape[0]]
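
Note: plugging the input_shape example from the comment into the new slice shows the trimming (illustrative values):

    input_shape = [[1, 32, 4, 32], [1], [1], [1]]    # one tensor, three scalar reshape args
    keep = sum(i != [1] for i in input_shape)        # -> 1
    tensor_input_allocs = [0, 1, 2, 3][:keep]        # -> [0]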
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.graph.base_graph import OutTensor
  from model_compression_toolkit.core.common.graph.edge import Edge
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
  from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
- CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_INDICES, GET_ATTR
+ CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR
  from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
  from model_compression_toolkit.logger import Logger

@@ -140,7 +140,7 @@ def nodes_builder(model: GraphModule,
  weights.update({i: consts_dict[input_node]})

  tensor_meta = input_node.meta
- if tensor_meta[TYPE] == torch.Tensor:
+ if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
  input_shape += [list(tensor_meta[TENSOR_META].shape)]
  elif tensor_meta[TYPE] == tuple:
  input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
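
Note: torch.nn.Parameter subclasses torch.Tensor, but the node's meta records the concrete type, so the old equality check silently skipped Parameter inputs; the membership test now covers both. A quick standalone check:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    assert isinstance(p, torch.Tensor)               # Parameter is-a Tensor
    assert type(p) is torch.nn.parameter.Parameter   # but its concrete type differs,
    assert not (type(p) == torch.Tensor)             # so `== torch.Tensor` misses it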
@@ -159,8 +159,11 @@ def nodes_builder(model: GraphModule,

  # filter Nodes from framework attributes, we replace these attributes with nx graph nodes
  framework_attr_filtered = {}
+ framework_attr_nodes = {}
  for k, v in framework_attr.items():
- if not isinstance(v, torch.fx.node.Node):
+ if isinstance(v, torch.fx.node.Node):
+ framework_attr_nodes[k] = v
+ else:
  framework_attr_filtered[k] = v
  framework_attr = framework_attr_filtered

@@ -177,7 +180,7 @@ def nodes_builder(model: GraphModule,
  [isinstance(n, torch.fx.node.Node) for n in node.args[0]])
  inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
  node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
- tensor_input_index = []
+ tensor_input_alloc = []
  op_call_args = list(node.args)
  if inputs_as_list:
  op_call_args.pop(0)
@@ -185,7 +188,10 @@ def nodes_builder(model: GraphModule,
  for in_node in node.all_input_nodes:
  for i, arg in enumerate(node.args):
  if arg == in_node:
- tensor_input_index.append(i)
+ tensor_input_alloc.append(i)
+ for k, arg in framework_attr_nodes.items():
+ if arg == in_node:
+ tensor_input_alloc.append(k)

  # remove torch.fx.node.Node from inputs to graph_node_type
  op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
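
Note: with the added loop, tensor-valued kwargs contribute their key names to tensor_input_alloc alongside the positional indices gathered from node.args, which is what gives the list its mixed int/str contents. A schematic sketch (FX nodes replaced by a plain string for illustration):

    in_node = "fx_node"                        # stands in for a torch.fx Node
    args = (in_node, 2)
    framework_attr_nodes = {"other": in_node}  # tensor-valued kwargs collected above

    tensor_input_alloc = []
    for i, arg in enumerate(args):
        if arg == in_node:
            tensor_input_alloc.append(i)       # positional index
    for k, arg in framework_attr_nodes.items():
        if arg == in_node:
            tensor_input_alloc.append(k)       # kwarg name

    assert tensor_input_alloc == [0, "other"]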
@@ -197,7 +203,7 @@ def nodes_builder(model: GraphModule,
  OP_CALL_ARGS: op_call_args,
  OP_CALL_KWARGS: node_kwargs,
  INPUTS_AS_LIST: inputs_as_list,
- TENSOR_INPUT_INDICES: tensor_input_index}
+ TENSOR_INPUT_ALLOCS: tensor_input_alloc}
  else:
  graph_node_type = BaseNode
  kwargs = {}