mct-nightly 2.1.0.20240717.444__py3-none-any.whl → 2.1.0.20240719.444__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.1.0.20240717.444
3
+ Version: 2.1.0.20240719.444
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=lgLwXN1Xe39gig854ixa-q2WUH5g4yl5eKb4XC_EKsg,1573
1
+ model_compression_toolkit/__init__.py,sha256=JshHkyFrgHVzyqpaLEB0DIouHvcS77kBPMhPQ58BgKQ,1573
2
2
  model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -31,9 +31,9 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
31
31
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
32
32
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
33
33
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
34
- model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
34
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=Hqv5lsEYT2uz5FYAX44Rsps_Ax73_kVoXy_DNaXIddU,29448
35
35
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
36
- model_compression_toolkit/core/common/graph/functional_node.py,sha256=BbxQ-WRk4R-5hbpQDBANkhRRTkaG7eogeiJwLfLb_EU,3950
36
+ model_compression_toolkit/core/common/graph/functional_node.py,sha256=XvzydBSRxgpYdKS-aYVaWtH3FDzJPKGad3bai9wF3BI,3956
37
37
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
38
38
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
39
39
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
@@ -159,7 +159,7 @@ model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nB
159
159
  model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=urpfyHvIzD08QzPBWusVBT_dKZ8ZUf1I1zIQNb4qe5Y,2233
160
160
  model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
161
161
  model_compression_toolkit/core/keras/back2framework/instance_builder.py,sha256=fBj13c6zkVoWX4JJG18_uXPptiEJqXClE_zFbaFB6Q8,4517
162
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=KXA5rik1cvCSgIbybPfX3tsMlzoILDboVONGyqzXGh0,16290
162
+ model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=XFSSaET4oPWB_cx-Q_c9pDJfWyQ1qXT9JXBl5FJCTa4,18137
163
163
  model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=ygIS1WIiftF1VC3oGhc8N6j7MryKtWgEg8nr50p7f4U,15587
164
164
  model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py,sha256=5wFb4nx_F0Wu4c8pLf6n6OzxOHtpOJ6_3mQsNSXIudU,2481
165
165
  model_compression_toolkit/core/keras/graph_substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
@@ -199,7 +199,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
199
199
  model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
200
200
  model_compression_toolkit/core/keras/reader/common.py,sha256=eZWjBcvTDUX7fCWmy1OAH4lYLFTh59_UQ_nP_Gjp4yw,2594
201
201
  model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
202
- model_compression_toolkit/core/keras/reader/node_builder.py,sha256=SAPkgL8aqJjnB6eCucU2D4m50WACCzWC8wjCVtFnwp8,10424
202
+ model_compression_toolkit/core/keras/reader/node_builder.py,sha256=H4ZvaRKt7W7zKzmEDlhveHWf_rLlUCxM0G1s7lnLzn0,13953
203
203
  model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
204
204
  model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
205
205
  model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906
@@ -517,8 +517,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
517
517
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
518
518
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
519
519
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
520
- mct_nightly-2.1.0.20240717.444.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
521
- mct_nightly-2.1.0.20240717.444.dist-info/METADATA,sha256=-s4IptD94mb3pm2pxywOC5RpgZ1D0NULKBxTKOz-zNg,19719
522
- mct_nightly-2.1.0.20240717.444.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
523
- mct_nightly-2.1.0.20240717.444.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
524
- mct_nightly-2.1.0.20240717.444.dist-info/RECORD,,
520
+ mct_nightly-2.1.0.20240719.444.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
521
+ mct_nightly-2.1.0.20240719.444.dist-info/METADATA,sha256=fhA5w_OA1BHS2PndN7W4uW0KXhv7MM7IPKkTeJxpzLU,19719
522
+ mct_nightly-2.1.0.20240719.444.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
523
+ mct_nightly-2.1.0.20240719.444.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
524
+ mct_nightly-2.1.0.20240719.444.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.1.0.20240717.000444"
30
+ __version__ = "2.1.0.20240719.000444"
@@ -36,7 +36,7 @@ class BaseNode:
36
36
  framework_attr: Dict[str, Any],
37
37
  input_shape: Tuple[Any],
38
38
  output_shape: Tuple[Any],
39
- weights: Dict[str, np.ndarray],
39
+ weights: Dict[Union[str, int], np.ndarray],
40
40
  layer_class: type,
41
41
  reuse: bool = False,
42
42
  reuse_group: str = None,
@@ -59,7 +59,7 @@ class FunctionalNode(BaseNode):
59
59
  has_activation=has_activation)
60
60
 
61
61
  self.op_call_kwargs = op_call_kwargs
62
- self.op_call_args = op_call_args
62
+ self.op_call_args = list(op_call_args)
63
63
  self.functional_op = functional_op
64
64
  self.inputs_as_list = inputs_as_list
65
65
  self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from copy import copy
15
16
 
16
17
  import tensorflow as tf
17
18
  from keras.models import Model
@@ -19,6 +20,7 @@ from packaging import version
19
20
 
20
21
  from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
21
22
  from model_compression_toolkit.core.common.user_info import UserInformation
23
+ from model_compression_toolkit.logger import Logger
22
24
 
23
25
  if version.parse(tf.__version__) >= version.parse("2.13"):
24
26
  from keras import Input
@@ -271,15 +273,38 @@ class KerasModelBuilder(BaseModelBuilder):
271
273
  out_tensors_of_n_float)
272
274
  else:
273
275
  input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
276
+ if isinstance(n, FunctionalNode):
277
+ op_call_kwargs = {} if n.op_call_kwargs is None else copy(n.op_call_kwargs)
274
278
  if not isinstance(op_func, KerasQuantizationWrapper):
275
279
  # The KerasQuantizationWrapper will insert the quantized positional weights internally.
276
- input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
280
+ if isinstance(n, FunctionalNode):
281
+ if n.tensor_input_allocs is not None:
282
+ if n.inputs_as_list:
283
+ input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
284
+ else:
285
+ # If the were any const attributes in the layer's inputs, we retrieve them as kwargs
286
+ # for the operator call.
287
+ for pos, k in enumerate(n.tensor_input_allocs):
288
+ if k not in op_call_kwargs: # op_call_kwargs is initialized because we are under FunctionalNode
289
+ # If the argument is saved in tensor_input_allocs but does not exists in the node kwargs
290
+ # then it is expected to be either an input tensor or a positional weight of the node.
291
+ arg = n.weights.get(pos)
292
+ if arg is None:
293
+ if len(input_tensors) == 0:
294
+ Logger.critical(f"Couldn't find a weight or input tensor matching operator's "
295
+ f"argument name '{k}' in location {pos} for node {n.name}.")
296
+ arg = input_tensors.pop(0)
297
+ op_call_kwargs.update({k: arg})
298
+ else:
299
+ # If the operator is not a functional node then positional weights should be inserted
300
+ # into the inputs list.
301
+ input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
277
302
  # Build a functional node using its args
278
303
  if isinstance(n, FunctionalNode):
279
304
  if n.inputs_as_list: # If the first argument should be a list of tensors:
280
- out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **n.op_call_kwargs)
305
+ out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **op_call_kwargs)
281
306
  else: # If the input tensors should not be a list but iterated:
282
- out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **n.op_call_kwargs)
307
+ out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **op_call_kwargs)
283
308
  else:
284
309
  # If operator expects a single input tensor, it cannot be a list as it should
285
310
  # have a dtype field.
@@ -12,7 +12,9 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Any, List, Dict
15
+ from copy import copy
16
+
17
+ from typing import Any, List, Dict, Union, Tuple
16
18
 
17
19
  import tensorflow as tf
18
20
  from tensorflow.python.util import tf_inspect
@@ -41,7 +43,7 @@ layers = keras.layers
41
43
 
42
44
  REUSED_IDENTIFIER = '_reused_'
43
45
 
44
- is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, float))
46
+ is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, tuple, list))
45
47
  is_tensor = lambda x: isinstance(x, KerasTensor)
46
48
 
47
49
 
@@ -62,35 +64,139 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
62
64
  Positional weights are saved according to their index in the node's call arguments, so
63
65
  need to know the function arguments' names in case the weights are in the kwargs.
64
66
 
65
- Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so
66
- it will only include the arguments that may contain constants. For example, we don't
67
- want the transpose_a attribute of tf.matmul to be saved as a constant.
68
-
69
- Every operation we add support to, needs to be added here.
70
-
71
67
  Args:
72
68
  tfoplambda_layer: TFOpLambda layer.
73
69
 
74
70
  Returns:
75
71
  A dictionary with argument number and index: {arg_name: arg_index}.
76
72
  """
77
- kwargs2index = {tf.add: {'x': 0, 'y': 1},
78
- tf.subtract: {'x': 0, 'y': 1},
79
- tf.divide: {'x': 0, 'y': 1},
80
- tf.truediv: {'x': 0, 'y': 1},
81
- tf.multiply: {'x': 0, 'y': 1},
82
- tf.pow: {'x': 0, 'y': 1},
83
- tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function)
84
- if not kwargs2index:
85
- # In TF 2.15 the function attribute is different and doesn't match the original
86
- # operation object we use. Therefore, we extract kwargs2index with the symbol.
87
- kwargs2index = {'__operators__.add': {'x': 0, 'y': 1},
88
- 'math.add': {'x': 0, 'y': 1},
89
- 'math.multiply': {'x': 0, 'y': 1},
90
- 'linalg.matmul': {'a': 0, 'b': 1},
91
- 'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {})
92
-
93
- return kwargs2index
73
+
74
+ full_args = tf_inspect.getfullargspec(tfoplambda_layer.function).args
75
+
76
+ return {arg_name: i for i, arg_name in enumerate(full_args)}
77
+
78
+
79
+ def _extract_const_attrs_from_kwargs(op_call_kwargs: Dict[str, Any],
80
+ kwarg2index: Dict[str, int],
81
+ weights: Dict[Union[str, int], Any]) -> Dict[str, Any]:
82
+ """
83
+ Extract const weights of the layer from the operator's key arguments dictionary.
84
+ This function extracts the attributes, updates the nodes weights dictionary and removes them from the original
85
+ kwargs mapping.
86
+
87
+ Args:
88
+ op_call_kwargs: A mapping of the operator key arguments.
89
+ kwarg2index: A dictionary with argument number and index: {arg_name: arg_index}.
90
+ weights: Node weights mapping. This dictionary is modified by this function.
91
+
92
+ Returns: A modified operator key arguments mapping.
93
+
94
+ """
95
+
96
+ # read weights from call kwargs
97
+ for k, v in op_call_kwargs.items():
98
+ if is_const(v):
99
+ # if k in kwarg2index:
100
+ weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
101
+
102
+ # remove weights and KerasTensors from op_call_kwargs
103
+ op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
104
+ if not (kwarg2index.get(k) in weights or is_tensor(v))}
105
+
106
+ return op_call_kwargs
107
+
108
+
109
+ def _build_arguments_alloc(n: KerasNode, inputs_as_list: bool, kwarg2index: Dict[str, int]) -> List:
110
+ """
111
+ Builds arguments allocation list.
112
+ In Keras, if there is any argument that is a constant, we convert all arguments and inputs to be
113
+ considered as op kwargs for simpler reconstruction of the model from the graph later.
114
+ Therefore, we build a location list that includes the argument names (keys).
115
+ If the input is a list, then we don't need to save the keys, since we can assume that all possible constant
116
+ arguments are within the first argument (the list) and are stored by their position in the list.
117
+
118
+ Args:
119
+ n: fx node.
120
+ inputs_as_list: Is node's inputs are a list.
121
+
122
+ Returns:
123
+ A list of argument allocations in the node's inputs.
124
+
125
+ """
126
+
127
+ tensor_input_alloc = []
128
+ op_call_args = list(n.call_args)
129
+ if not inputs_as_list:
130
+ sorted_kwargs_pos = sorted(kwarg2index.items(), key=lambda x: x[1])
131
+ tensor_input_alloc = [k for k, _ in sorted_kwargs_pos[:len(op_call_args)]]
132
+ for k, idx in sorted_kwargs_pos[len(op_call_args):]:
133
+ if k in n.call_kwargs:
134
+ tensor_input_alloc.append(k)
135
+
136
+ return tensor_input_alloc
137
+
138
+ def _extract_const_attrs_from_args(op_call_args: List[Any],
139
+ op_call_kwargs: Dict[str, Any],
140
+ inputs_as_list: bool,
141
+ tensor_inputs_alloc: List,
142
+ weights: Dict[Union[str, int], Any]) -> Tuple:
143
+ """
144
+ Extract const weights of the layer from the operator's arguments list.
145
+ This function extracts the attributes, updates the nodes weights dictionary and removes them from the original
146
+ arguments list.
147
+
148
+ Args:
149
+ op_call_args: A list of the operator arguments.
150
+ op_call_kwargs: A mapping of key-arguments of the operator.
151
+ inputs_as_list: Whether the input of the layer is a list.
152
+ tensor_inputs_alloc: Allocation of argument inputs to the operator (if there are const inputs, otherwise None).
153
+ weights: Node weights mapping. This dictionary is modified by this function.
154
+
155
+ Returns: A modified operator arguments list.
156
+
157
+ """
158
+
159
+ move_args_to_kwargs = tensor_inputs_alloc is not None and len(tensor_inputs_alloc) > 0
160
+
161
+ # read weights from call args
162
+ for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
163
+ if is_const(arg):
164
+ weights.update({i: to_numpy(arg, is_single_tensor=True)})
165
+ else:
166
+ if not inputs_as_list:
167
+ if move_args_to_kwargs:
168
+ # In this case we move all arguments and inputs to the kwargs
169
+ op_call_kwargs.update({tensor_inputs_alloc[i]: arg})
170
+
171
+ # remove weights and KerasTensors from op_call_args
172
+ if inputs_as_list:
173
+ op_call_args = tuple(op_call_args[1:])
174
+ else:
175
+ op_call_args = tuple([a for i, a in enumerate(op_call_args)
176
+ if not (i in weights or is_tensor(a) or (move_args_to_kwargs and tensor_inputs_alloc[i]
177
+ in op_call_kwargs))])
178
+
179
+ return op_call_args
180
+
181
+
182
+ def _has_const_attributes(op_call_args: List, op_call_kwargs: Dict, input_as_list: bool) -> bool:
183
+ """
184
+ Returns whether the layer's input include a constant tensor (that we might want to quantize).
185
+
186
+ Args:
187
+ op_call_args: A list of arguments to the layer.
188
+ op_call_kwargs: A dictionary of key-arguments to the layer.
189
+ input_as_list: Whether the input to the layer is a list of tensors.
190
+
191
+ Returns: True if the input arguments include a constant tensor, False otherwise.
192
+
193
+ """
194
+ if input_as_list:
195
+ return any([is_const(a) for a in op_call_args[0]])
196
+ const_args = [a for a in op_call_args if is_const(a)]
197
+ const_kwargs = [k for k, v in op_call_kwargs.items() if is_const(v)]
198
+
199
+ return len(const_args) > 0 or len(const_kwargs) > 0
94
200
 
95
201
 
96
202
  def build_node(node: KerasNode,
@@ -110,8 +216,8 @@ def build_node(node: KerasNode,
110
216
  """
111
217
  keras_layer = node.layer # get the layer the node represents.
112
218
  layer_config = keras_layer.get_config() # layer configuration to reconstruct it.
113
- op_call_args = node.call_args
114
- op_call_kwargs = node.call_kwargs
219
+ op_call_args = copy(node.call_args)
220
+ op_call_kwargs = copy(node.call_kwargs)
115
221
  layer_class = type(keras_layer) # class path to instantiating it in back2framework.
116
222
  weights = {v.name: v.numpy() for v in keras_layer.weights} # layer's weights
117
223
 
@@ -152,32 +258,14 @@ def build_node(node: KerasNode,
152
258
  if len(weights) > 0:
153
259
  Logger.critical('Functional nodes are not expected to have weights in this framework.')
154
260
 
155
- # read weights from call args
156
- tf_function_symbols = get_tf_function_symbols()
157
- for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
158
- if is_const(arg) or (
159
- keras_layer.symbol in tf_function_symbols and
160
- isinstance(arg, (tuple, list))):
161
- if inputs_as_list or i in kwarg2index.values():
162
- weights.update({i: to_numpy(arg, is_single_tensor=True)})
163
- # remove weights and KerasTensors and weights from op_call_args
164
- if inputs_as_list:
165
- op_call_args = tuple(op_call_args[1:])
166
- else:
167
- op_call_args = tuple([a for i, a in enumerate(op_call_args)
168
- if not (i in weights or is_tensor(a))])
169
-
170
- # read weights from call kwargs
171
- weight_keys = []
172
- for k, v in op_call_kwargs.items():
173
- if is_const(v) or (keras_layer.symbol in tf_function_symbols and
174
- isinstance(v, (tuple, list))):
175
- if k in kwarg2index:
176
- weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
177
- weight_keys.append(k)
178
- # remove weights and KerasTensors and weights from op_call_kwargs
179
- op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
180
- if not (kwarg2index.get(k) in weights or is_tensor(v))}
261
+ # Build tensor_input_alloc required for the model builder. All inputs are received as a list in the builder,
262
+ # so tensor_input_alloc is used to allocate each input in the correct place in the node's args & kwargs.
263
+ tensor_input_alloc = None if not _has_const_attributes(op_call_args, op_call_kwargs, inputs_as_list) \
264
+ else _build_arguments_alloc(node, inputs_as_list, kwarg2index)
265
+
266
+ op_call_args = _extract_const_attrs_from_args(op_call_args, op_call_kwargs, inputs_as_list,
267
+ tensor_input_alloc, weights)
268
+ op_call_kwargs = _extract_const_attrs_from_kwargs(op_call_kwargs, kwarg2index, weights)
181
269
 
182
270
  node = FunctionalNode(node_name,
183
271
  layer_config,
@@ -190,7 +278,8 @@ def build_node(node: KerasNode,
190
278
  is_reused,
191
279
  reuse_group,
192
280
  functional_op=keras_layer.function,
193
- inputs_as_list=inputs_as_list)
281
+ inputs_as_list=inputs_as_list,
282
+ tensor_input_allocs=tensor_input_alloc)
194
283
  else:
195
284
  # Read constant weights from layers such as layers.Add
196
285
  if len(op_call_args) > 0 and isinstance(op_call_args[0], (list, tuple)):