mct-nightly 2.1.0.20240717.444__py3-none-any.whl → 2.1.0.20240719.444__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/RECORD +10 -10
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +1 -1
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +28 -3
- model_compression_toolkit/core/keras/reader/node_builder.py +143 -54
- {mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=JshHkyFrgHVzyqpaLEB0DIouHvcS77kBPMhPQ58BgKQ,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -31,9 +31,9 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
|
|
31
31
|
model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
|
32
32
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
33
33
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
|
34
|
-
model_compression_toolkit/core/common/graph/base_node.py,sha256=
|
34
|
+
model_compression_toolkit/core/common/graph/base_node.py,sha256=Hqv5lsEYT2uz5FYAX44Rsps_Ax73_kVoXy_DNaXIddU,29448
|
35
35
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
36
|
-
model_compression_toolkit/core/common/graph/functional_node.py,sha256=
|
36
|
+
model_compression_toolkit/core/common/graph/functional_node.py,sha256=XvzydBSRxgpYdKS-aYVaWtH3FDzJPKGad3bai9wF3BI,3956
|
37
37
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
38
38
|
model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
|
39
39
|
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
|
@@ -159,7 +159,7 @@ model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nB
|
|
159
159
|
model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=urpfyHvIzD08QzPBWusVBT_dKZ8ZUf1I1zIQNb4qe5Y,2233
|
160
160
|
model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
|
161
161
|
model_compression_toolkit/core/keras/back2framework/instance_builder.py,sha256=fBj13c6zkVoWX4JJG18_uXPptiEJqXClE_zFbaFB6Q8,4517
|
162
|
-
model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=
|
162
|
+
model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=XFSSaET4oPWB_cx-Q_c9pDJfWyQ1qXT9JXBl5FJCTa4,18137
|
163
163
|
model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=ygIS1WIiftF1VC3oGhc8N6j7MryKtWgEg8nr50p7f4U,15587
|
164
164
|
model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py,sha256=5wFb4nx_F0Wu4c8pLf6n6OzxOHtpOJ6_3mQsNSXIudU,2481
|
165
165
|
model_compression_toolkit/core/keras/graph_substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
@@ -199,7 +199,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
|
|
199
199
|
model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
200
200
|
model_compression_toolkit/core/keras/reader/common.py,sha256=eZWjBcvTDUX7fCWmy1OAH4lYLFTh59_UQ_nP_Gjp4yw,2594
|
201
201
|
model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
|
202
|
-
model_compression_toolkit/core/keras/reader/node_builder.py,sha256=
|
202
|
+
model_compression_toolkit/core/keras/reader/node_builder.py,sha256=H4ZvaRKt7W7zKzmEDlhveHWf_rLlUCxM0G1s7lnLzn0,13953
|
203
203
|
model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
|
204
204
|
model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
205
205
|
model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906
|
@@ -517,8 +517,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
517
517
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
|
518
518
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
519
519
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
|
520
|
-
mct_nightly-2.1.0.
|
521
|
-
mct_nightly-2.1.0.
|
522
|
-
mct_nightly-2.1.0.
|
523
|
-
mct_nightly-2.1.0.
|
524
|
-
mct_nightly-2.1.0.
|
520
|
+
mct_nightly-2.1.0.20240719.444.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
521
|
+
mct_nightly-2.1.0.20240719.444.dist-info/METADATA,sha256=fhA5w_OA1BHS2PndN7W4uW0KXhv7MM7IPKkTeJxpzLU,19719
|
522
|
+
mct_nightly-2.1.0.20240719.444.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
523
|
+
mct_nightly-2.1.0.20240719.444.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
524
|
+
mct_nightly-2.1.0.20240719.444.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.1.0.
|
30
|
+
__version__ = "2.1.0.20240719.000444"
|
@@ -36,7 +36,7 @@ class BaseNode:
|
|
36
36
|
framework_attr: Dict[str, Any],
|
37
37
|
input_shape: Tuple[Any],
|
38
38
|
output_shape: Tuple[Any],
|
39
|
-
weights: Dict[str, np.ndarray],
|
39
|
+
weights: Dict[Union[str, int], np.ndarray],
|
40
40
|
layer_class: type,
|
41
41
|
reuse: bool = False,
|
42
42
|
reuse_group: str = None,
|
@@ -59,7 +59,7 @@ class FunctionalNode(BaseNode):
|
|
59
59
|
has_activation=has_activation)
|
60
60
|
|
61
61
|
self.op_call_kwargs = op_call_kwargs
|
62
|
-
self.op_call_args = op_call_args
|
62
|
+
self.op_call_args = list(op_call_args)
|
63
63
|
self.functional_op = functional_op
|
64
64
|
self.inputs_as_list = inputs_as_list
|
65
65
|
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
|
@@ -12,6 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from copy import copy
|
15
16
|
|
16
17
|
import tensorflow as tf
|
17
18
|
from keras.models import Model
|
@@ -19,6 +20,7 @@ from packaging import version
|
|
19
20
|
|
20
21
|
from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
|
21
22
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
23
|
+
from model_compression_toolkit.logger import Logger
|
22
24
|
|
23
25
|
if version.parse(tf.__version__) >= version.parse("2.13"):
|
24
26
|
from keras import Input
|
@@ -271,15 +273,38 @@ class KerasModelBuilder(BaseModelBuilder):
|
|
271
273
|
out_tensors_of_n_float)
|
272
274
|
else:
|
273
275
|
input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
|
276
|
+
if isinstance(n, FunctionalNode):
|
277
|
+
op_call_kwargs = {} if n.op_call_kwargs is None else copy(n.op_call_kwargs)
|
274
278
|
if not isinstance(op_func, KerasQuantizationWrapper):
|
275
279
|
# The KerasQuantizationWrapper will insert the quantized positional weights internally.
|
276
|
-
|
280
|
+
if isinstance(n, FunctionalNode):
|
281
|
+
if n.tensor_input_allocs is not None:
|
282
|
+
if n.inputs_as_list:
|
283
|
+
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
|
284
|
+
else:
|
285
|
+
# If the were any const attributes in the layer's inputs, we retrieve them as kwargs
|
286
|
+
# for the operator call.
|
287
|
+
for pos, k in enumerate(n.tensor_input_allocs):
|
288
|
+
if k not in op_call_kwargs: # op_call_kwargs is initialized because we are under FunctionalNode
|
289
|
+
# If the argument is saved in tensor_input_allocs but does not exists in the node kwargs
|
290
|
+
# then it is expected to be either an input tensor or a positional weight of the node.
|
291
|
+
arg = n.weights.get(pos)
|
292
|
+
if arg is None:
|
293
|
+
if len(input_tensors) == 0:
|
294
|
+
Logger.critical(f"Couldn't find a weight or input tensor matching operator's "
|
295
|
+
f"argument name '{k}' in location {pos} for node {n.name}.")
|
296
|
+
arg = input_tensors.pop(0)
|
297
|
+
op_call_kwargs.update({k: arg})
|
298
|
+
else:
|
299
|
+
# If the operator is not a functional node then positional weights should be inserted
|
300
|
+
# into the inputs list.
|
301
|
+
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
|
277
302
|
# Build a functional node using its args
|
278
303
|
if isinstance(n, FunctionalNode):
|
279
304
|
if n.inputs_as_list: # If the first argument should be a list of tensors:
|
280
|
-
out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **
|
305
|
+
out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **op_call_kwargs)
|
281
306
|
else: # If the input tensors should not be a list but iterated:
|
282
|
-
out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **
|
307
|
+
out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **op_call_kwargs)
|
283
308
|
else:
|
284
309
|
# If operator expects a single input tensor, it cannot be a list as it should
|
285
310
|
# have a dtype field.
|
@@ -12,7 +12,9 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from
|
15
|
+
from copy import copy
|
16
|
+
|
17
|
+
from typing import Any, List, Dict, Union, Tuple
|
16
18
|
|
17
19
|
import tensorflow as tf
|
18
20
|
from tensorflow.python.util import tf_inspect
|
@@ -41,7 +43,7 @@ layers = keras.layers
|
|
41
43
|
|
42
44
|
REUSED_IDENTIFIER = '_reused_'
|
43
45
|
|
44
|
-
is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray,
|
46
|
+
is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, tuple, list))
|
45
47
|
is_tensor = lambda x: isinstance(x, KerasTensor)
|
46
48
|
|
47
49
|
|
@@ -62,35 +64,139 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
|
|
62
64
|
Positional weights are saved according to their index in the node's call arguments, so
|
63
65
|
need to know the function arguments' names in case the weights are in the kwargs.
|
64
66
|
|
65
|
-
Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so
|
66
|
-
it will only include the arguments that may contain constants. For example, we don't
|
67
|
-
want the transpose_a attribute of tf.matmul to be saved as a constant.
|
68
|
-
|
69
|
-
Every operation we add support to, needs to be added here.
|
70
|
-
|
71
67
|
Args:
|
72
68
|
tfoplambda_layer: TFOpLambda layer.
|
73
69
|
|
74
70
|
Returns:
|
75
71
|
A dictionary with argument number and index: {arg_name: arg_index}.
|
76
72
|
"""
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
73
|
+
|
74
|
+
full_args = tf_inspect.getfullargspec(tfoplambda_layer.function).args
|
75
|
+
|
76
|
+
return {arg_name: i for i, arg_name in enumerate(full_args)}
|
77
|
+
|
78
|
+
|
79
|
+
def _extract_const_attrs_from_kwargs(op_call_kwargs: Dict[str, Any],
|
80
|
+
kwarg2index: Dict[str, int],
|
81
|
+
weights: Dict[Union[str, int], Any]) -> Dict[str, Any]:
|
82
|
+
"""
|
83
|
+
Extract const weights of the layer from the operator's key arguments dictionary.
|
84
|
+
This function extracts the attributes, updates the nodes weights dictionary and removes them from the original
|
85
|
+
kwargs mapping.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
op_call_kwargs: A mapping of the operator key arguments.
|
89
|
+
kwarg2index: A dictionary with argument number and index: {arg_name: arg_index}.
|
90
|
+
weights: Node weights mapping. This dictionary is modified by this function.
|
91
|
+
|
92
|
+
Returns: A modified operator key arguments mapping.
|
93
|
+
|
94
|
+
"""
|
95
|
+
|
96
|
+
# read weights from call kwargs
|
97
|
+
for k, v in op_call_kwargs.items():
|
98
|
+
if is_const(v):
|
99
|
+
# if k in kwarg2index:
|
100
|
+
weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
|
101
|
+
|
102
|
+
# remove weights and KerasTensors from op_call_kwargs
|
103
|
+
op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
|
104
|
+
if not (kwarg2index.get(k) in weights or is_tensor(v))}
|
105
|
+
|
106
|
+
return op_call_kwargs
|
107
|
+
|
108
|
+
|
109
|
+
def _build_arguments_alloc(n: KerasNode, inputs_as_list: bool, kwarg2index: Dict[str, int]) -> List:
|
110
|
+
"""
|
111
|
+
Builds arguments allocation list.
|
112
|
+
In Keras, if there is any argument that is a constant, we convert all arguments and inputs to be
|
113
|
+
considered as op kwargs for simpler reconstruction of the model from the graph later.
|
114
|
+
Therefore, we build a location list that includes the argument names (keys).
|
115
|
+
If the input is a list, then we don't need to save the keys, since we can assume that all possible constant
|
116
|
+
arguments are within the first argument (the list) and are stored by their position in the list.
|
117
|
+
|
118
|
+
Args:
|
119
|
+
n: fx node.
|
120
|
+
inputs_as_list: Is node's inputs are a list.
|
121
|
+
|
122
|
+
Returns:
|
123
|
+
A list of argument allocations in the node's inputs.
|
124
|
+
|
125
|
+
"""
|
126
|
+
|
127
|
+
tensor_input_alloc = []
|
128
|
+
op_call_args = list(n.call_args)
|
129
|
+
if not inputs_as_list:
|
130
|
+
sorted_kwargs_pos = sorted(kwarg2index.items(), key=lambda x: x[1])
|
131
|
+
tensor_input_alloc = [k for k, _ in sorted_kwargs_pos[:len(op_call_args)]]
|
132
|
+
for k, idx in sorted_kwargs_pos[len(op_call_args):]:
|
133
|
+
if k in n.call_kwargs:
|
134
|
+
tensor_input_alloc.append(k)
|
135
|
+
|
136
|
+
return tensor_input_alloc
|
137
|
+
|
138
|
+
def _extract_const_attrs_from_args(op_call_args: List[Any],
|
139
|
+
op_call_kwargs: Dict[str, Any],
|
140
|
+
inputs_as_list: bool,
|
141
|
+
tensor_inputs_alloc: List,
|
142
|
+
weights: Dict[Union[str, int], Any]) -> Tuple:
|
143
|
+
"""
|
144
|
+
Extract const weights of the layer from the operator's arguments list.
|
145
|
+
This function extracts the attributes, updates the nodes weights dictionary and removes them from the original
|
146
|
+
arguments list.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
op_call_args: A list of the operator arguments.
|
150
|
+
op_call_kwargs: A mapping of key-arguments of the operator.
|
151
|
+
inputs_as_list: Whether the input of the layer is a list.
|
152
|
+
tensor_inputs_alloc: Allocation of argument inputs to the operator (if there are const inputs, otherwise None).
|
153
|
+
weights: Node weights mapping. This dictionary is modified by this function.
|
154
|
+
|
155
|
+
Returns: A modified operator arguments list.
|
156
|
+
|
157
|
+
"""
|
158
|
+
|
159
|
+
move_args_to_kwargs = tensor_inputs_alloc is not None and len(tensor_inputs_alloc) > 0
|
160
|
+
|
161
|
+
# read weights from call args
|
162
|
+
for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
|
163
|
+
if is_const(arg):
|
164
|
+
weights.update({i: to_numpy(arg, is_single_tensor=True)})
|
165
|
+
else:
|
166
|
+
if not inputs_as_list:
|
167
|
+
if move_args_to_kwargs:
|
168
|
+
# In this case we move all arguments and inputs to the kwargs
|
169
|
+
op_call_kwargs.update({tensor_inputs_alloc[i]: arg})
|
170
|
+
|
171
|
+
# remove weights and KerasTensors from op_call_args
|
172
|
+
if inputs_as_list:
|
173
|
+
op_call_args = tuple(op_call_args[1:])
|
174
|
+
else:
|
175
|
+
op_call_args = tuple([a for i, a in enumerate(op_call_args)
|
176
|
+
if not (i in weights or is_tensor(a) or (move_args_to_kwargs and tensor_inputs_alloc[i]
|
177
|
+
in op_call_kwargs))])
|
178
|
+
|
179
|
+
return op_call_args
|
180
|
+
|
181
|
+
|
182
|
+
def _has_const_attributes(op_call_args: List, op_call_kwargs: Dict, input_as_list: bool) -> bool:
|
183
|
+
"""
|
184
|
+
Returns whether the layer's input include a constant tensor (that we might want to quantize).
|
185
|
+
|
186
|
+
Args:
|
187
|
+
op_call_args: A list of arguments to the layer.
|
188
|
+
op_call_kwargs: A dictionary of key-arguments to the layer.
|
189
|
+
input_as_list: Whether the input to the layer is a list of tensors.
|
190
|
+
|
191
|
+
Returns: True if the input arguments include a constant tensor, False otherwise.
|
192
|
+
|
193
|
+
"""
|
194
|
+
if input_as_list:
|
195
|
+
return any([is_const(a) for a in op_call_args[0]])
|
196
|
+
const_args = [a for a in op_call_args if is_const(a)]
|
197
|
+
const_kwargs = [k for k, v in op_call_kwargs.items() if is_const(v)]
|
198
|
+
|
199
|
+
return len(const_args) > 0 or len(const_kwargs) > 0
|
94
200
|
|
95
201
|
|
96
202
|
def build_node(node: KerasNode,
|
@@ -110,8 +216,8 @@ def build_node(node: KerasNode,
|
|
110
216
|
"""
|
111
217
|
keras_layer = node.layer # get the layer the node represents.
|
112
218
|
layer_config = keras_layer.get_config() # layer configuration to reconstruct it.
|
113
|
-
op_call_args = node.call_args
|
114
|
-
op_call_kwargs = node.call_kwargs
|
219
|
+
op_call_args = copy(node.call_args)
|
220
|
+
op_call_kwargs = copy(node.call_kwargs)
|
115
221
|
layer_class = type(keras_layer) # class path to instantiating it in back2framework.
|
116
222
|
weights = {v.name: v.numpy() for v in keras_layer.weights} # layer's weights
|
117
223
|
|
@@ -152,32 +258,14 @@ def build_node(node: KerasNode,
|
|
152
258
|
if len(weights) > 0:
|
153
259
|
Logger.critical('Functional nodes are not expected to have weights in this framework.')
|
154
260
|
|
155
|
-
#
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
# remove weights and KerasTensors and weights from op_call_args
|
164
|
-
if inputs_as_list:
|
165
|
-
op_call_args = tuple(op_call_args[1:])
|
166
|
-
else:
|
167
|
-
op_call_args = tuple([a for i, a in enumerate(op_call_args)
|
168
|
-
if not (i in weights or is_tensor(a))])
|
169
|
-
|
170
|
-
# read weights from call kwargs
|
171
|
-
weight_keys = []
|
172
|
-
for k, v in op_call_kwargs.items():
|
173
|
-
if is_const(v) or (keras_layer.symbol in tf_function_symbols and
|
174
|
-
isinstance(v, (tuple, list))):
|
175
|
-
if k in kwarg2index:
|
176
|
-
weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
|
177
|
-
weight_keys.append(k)
|
178
|
-
# remove weights and KerasTensors and weights from op_call_kwargs
|
179
|
-
op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
|
180
|
-
if not (kwarg2index.get(k) in weights or is_tensor(v))}
|
261
|
+
# Build tensor_input_alloc required for the model builder. All inputs are received as a list in the builder,
|
262
|
+
# so tensor_input_alloc is used to allocate each input in the correct place in the node's args & kwargs.
|
263
|
+
tensor_input_alloc = None if not _has_const_attributes(op_call_args, op_call_kwargs, inputs_as_list) \
|
264
|
+
else _build_arguments_alloc(node, inputs_as_list, kwarg2index)
|
265
|
+
|
266
|
+
op_call_args = _extract_const_attrs_from_args(op_call_args, op_call_kwargs, inputs_as_list,
|
267
|
+
tensor_input_alloc, weights)
|
268
|
+
op_call_kwargs = _extract_const_attrs_from_kwargs(op_call_kwargs, kwarg2index, weights)
|
181
269
|
|
182
270
|
node = FunctionalNode(node_name,
|
183
271
|
layer_config,
|
@@ -190,7 +278,8 @@ def build_node(node: KerasNode,
|
|
190
278
|
is_reused,
|
191
279
|
reuse_group,
|
192
280
|
functional_op=keras_layer.function,
|
193
|
-
inputs_as_list=inputs_as_list
|
281
|
+
inputs_as_list=inputs_as_list,
|
282
|
+
tensor_input_allocs=tensor_input_alloc)
|
194
283
|
else:
|
195
284
|
# Read constant weights from layers such as layers.Add
|
196
285
|
if len(op_call_args) > 0 and isinstance(op_call_args[0], (list, tuple)):
|
{mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.1.0.20240717.444.dist-info → mct_nightly-2.1.0.20240719.444.dist-info}/top_level.txt
RENAMED
File without changes
|