mct-nightly 2.1.0.20240616.434__py3-none-any.whl → 2.1.0.20240617.451__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/RECORD +12 -12
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/functional_node.py +3 -3
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +23 -13
- model_compression_toolkit/core/pytorch/constants.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +3 -3
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +12 -6
- {mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=1OlgfgkWMboHXhaDyyG9E0dc_vedCsy6r_gAtSq-lfY,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -33,7 +33,7 @@ model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaU
|
|
33
33
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
|
34
34
|
model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
|
35
35
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
36
|
-
model_compression_toolkit/core/common/graph/functional_node.py,sha256=
|
36
|
+
model_compression_toolkit/core/common/graph/functional_node.py,sha256=_6HsBeLlrpLvXhLPRJswcyDa4z16-O3xzHzGuv46zBc,3897
|
37
37
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
38
38
|
model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
|
39
39
|
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
|
@@ -186,7 +186,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_a
|
|
186
186
|
model_compression_toolkit/core/keras/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
187
187
|
model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py,sha256=4eJKq_Fx4mm_VuBDeeti0fTcUk1lL2yjebxCugJhvrA,8871
|
188
188
|
model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py,sha256=hRfAjgZakDaIMuERmTVjJSa_Ww6FmEudYPO9R7SuYuQ,3914
|
189
|
-
model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py,sha256=
|
189
|
+
model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py,sha256=Xogd90kPZPvKbplZQv5B77Dq_m4aW5-bL6Jxh33VZWs,12213
|
190
190
|
model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
191
191
|
model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
|
192
192
|
model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=Ziydik2j-LvNBXP3TSfUD6rEezPAikzQGib0_IXkmGM,6729
|
@@ -210,7 +210,7 @@ model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9H
|
|
210
210
|
model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
|
211
211
|
model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
212
212
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
213
|
-
model_compression_toolkit/core/pytorch/constants.py,sha256=
|
213
|
+
model_compression_toolkit/core/pytorch/constants.py,sha256=AguUnAsNlj41gwuKIP_7nos3FcJHsIAjewLXSQdrDQM,2624
|
214
214
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
215
215
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
216
216
|
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=7CFt1Y3fiDaKkEVvlDd76ZmucCuVp6OZNQwwqJezKbU,27547
|
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
|
|
222
222
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
223
223
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
224
224
|
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
|
225
|
-
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=
|
225
|
+
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Pvpkirt3ziWEXDEspgOhR8ALf-XAZUh-78IkXg9YMWs,18830
|
226
226
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
227
227
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
228
228
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
|
@@ -240,7 +240,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_
|
|
240
240
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=VNg-VgzCxSyqy2J3neEPl6U0SPO8UIVU_T47bGhz4FE,38459
|
241
241
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=q1a3HieQtaOmWG2WGXp6GHYAvxa3CZ9dJUx9dqMAsS8,5695
|
242
242
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py,sha256=joHjwiUxccypMHkTy46rI91VyapLn9yJ2YRo5ISnOH4,1987
|
243
|
-
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=
|
243
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
|
244
244
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
|
245
245
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
|
246
246
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=lOPl5zDU3FoR9WmlxO04Pfi65MimK0gmnuHzQJodQdY,10668
|
@@ -261,7 +261,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
|
|
261
261
|
model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
|
262
262
|
model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
|
263
263
|
model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
264
|
-
model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=
|
264
|
+
model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=LiGV-ZqlhxN1evpM-ur2dDVPowhrLwO7JZa7AGPftSk,12913
|
265
265
|
model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=TaolORuwBZEddWe-q0Mg79Nmswz-Sq3-9-4o8UxFQ50,1028
|
266
266
|
model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
|
267
267
|
model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
@@ -491,8 +491,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
491
491
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
492
492
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
493
493
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
|
494
|
-
mct_nightly-2.1.0.
|
495
|
-
mct_nightly-2.1.0.
|
496
|
-
mct_nightly-2.1.0.
|
497
|
-
mct_nightly-2.1.0.
|
498
|
-
mct_nightly-2.1.0.
|
494
|
+
mct_nightly-2.1.0.20240617.451.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
495
|
+
mct_nightly-2.1.0.20240617.451.dist-info/METADATA,sha256=-ZUI2y7SZOGKyLl6qpBE9onj-lZwfGE0wLLJI5WeqIE,19721
|
496
|
+
mct_nightly-2.1.0.20240617.451.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
497
|
+
mct_nightly-2.1.0.20240617.451.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
498
|
+
mct_nightly-2.1.0.20240617.451.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.1.0.
|
30
|
+
__version__ = "2.1.0.20240617.000451"
|
@@ -25,7 +25,7 @@ class FunctionalNode(BaseNode):
|
|
25
25
|
functional_op: Any = None,
|
26
26
|
inputs_as_list: bool = False,
|
27
27
|
has_activation: bool = True,
|
28
|
-
|
28
|
+
tensor_input_allocs = None):
|
29
29
|
"""
|
30
30
|
Init a FunctionalNode object.
|
31
31
|
|
@@ -44,7 +44,7 @@ class FunctionalNode(BaseNode):
|
|
44
44
|
functional_op: The op the node implements.
|
45
45
|
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
|
46
46
|
has_activation: Whether the node has activations that we might want to quantize.
|
47
|
-
|
47
|
+
tensor_input_allocs: A list of indices for activation tensors in the node's input tensor list
|
48
48
|
|
49
49
|
"""
|
50
50
|
|
@@ -63,7 +63,7 @@ class FunctionalNode(BaseNode):
|
|
63
63
|
self.op_call_args = op_call_args
|
64
64
|
self.functional_op = functional_op
|
65
65
|
self.inputs_as_list = inputs_as_list
|
66
|
-
self.
|
66
|
+
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
|
67
67
|
|
68
68
|
@property
|
69
69
|
def type(self):
|
@@ -74,7 +74,7 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
|
|
74
74
|
model, _ = FloatKerasModelBuilder(graph=self.graph).build_model()
|
75
75
|
|
76
76
|
# Initiate a gradient tape for automatic differentiation
|
77
|
-
with
|
77
|
+
with tf.GradientTape(persistent=True) as tape:
|
78
78
|
# Perform a forward pass (inference) to get the output, while watching
|
79
79
|
# the input tensor for gradient computation
|
80
80
|
tape.watch(self.input_images)
|
@@ -22,6 +22,7 @@ from networkx import topological_sort
|
|
22
22
|
|
23
23
|
from model_compression_toolkit.core import FrameworkInfo
|
24
24
|
from model_compression_toolkit.core import common
|
25
|
+
from model_compression_toolkit.logger import Logger
|
25
26
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
26
27
|
from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
|
27
28
|
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
|
@@ -66,8 +67,8 @@ def _build_input_tensors_list(node: BaseNode,
|
|
66
67
|
return input_tensors
|
67
68
|
|
68
69
|
|
69
|
-
def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
|
70
|
-
|
70
|
+
def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
|
71
|
+
tensor_input_allocs: List = None) -> Tuple[List, Dict]:
|
71
72
|
"""
|
72
73
|
Merge input tensors list with positional weights and op_call_args, according to correct order.
|
73
74
|
|
@@ -75,22 +76,30 @@ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
|
|
75
76
|
_node: The node the inputs are for.
|
76
77
|
input_tensors: activation input tensors to node.
|
77
78
|
op_call_args: framework node call args.
|
79
|
+
op_call_kwargs: framework node call kwargs.
|
80
|
+
tensor_input_allocs: List of input allocations to node.
|
78
81
|
|
79
82
|
Returns:
|
80
83
|
Combined list of input_tensors and op_call_args.
|
81
84
|
"""
|
82
|
-
if isinstance(_node, FunctionalNode) and _node.
|
85
|
+
if isinstance(_node, FunctionalNode) and _node.tensor_input_allocs:
|
83
86
|
_input_list = op_call_args.copy()
|
84
|
-
if
|
85
|
-
|
86
|
-
|
87
|
-
f'Mismatch between input tensors ({len(
|
88
|
-
|
89
|
-
|
87
|
+
if tensor_input_allocs is None:
|
88
|
+
tensor_input_allocs = _node.tensor_input_allocs
|
89
|
+
if len(tensor_input_allocs) != len(input_tensors):
|
90
|
+
Logger.error(f'Mismatch between input tensors ({len(tensor_input_allocs)}) '
|
91
|
+
f'and indices {len(input_tensors)} in node {_node.name}.') # pragma: no cover
|
92
|
+
for i, t in zip(tensor_input_allocs, input_tensors):
|
93
|
+
# insert input tensors in either args or kwargs, according to tensor_input_allocs
|
94
|
+
if isinstance(i, str):
|
95
|
+
assert i not in op_call_kwargs
|
96
|
+
op_call_kwargs.update({i: t})
|
97
|
+
else:
|
98
|
+
_input_list.insert(i, t)
|
90
99
|
else:
|
91
100
|
_input_list = input_tensors + op_call_args
|
92
101
|
|
93
|
-
return _input_list
|
102
|
+
return _input_list, op_call_kwargs
|
94
103
|
|
95
104
|
|
96
105
|
def _run_operation(n: BaseNode,
|
@@ -125,14 +134,15 @@ def _run_operation(n: BaseNode,
|
|
125
134
|
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
|
126
135
|
input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
|
127
136
|
for t in input_tensors]
|
128
|
-
|
137
|
+
_tensor_input_allocs = None
|
129
138
|
else:
|
130
|
-
|
139
|
+
_tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]
|
131
140
|
|
132
141
|
if isinstance(n, FunctionalNode) and n.inputs_as_list:
|
133
142
|
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
|
134
143
|
else:
|
135
|
-
merged_inputs = _merge_inputs(n, input_tensors, op_call_args,
|
144
|
+
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
|
145
|
+
tensor_input_allocs=_tensor_input_allocs)
|
136
146
|
out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
|
137
147
|
|
138
148
|
# Add a fake quant node if the node has an activation threshold.
|
@@ -40,7 +40,7 @@ FUNCTIONAL_OP = 'functional_op'
|
|
40
40
|
OP_CALL_ARGS = 'op_call_args'
|
41
41
|
OP_CALL_KWARGS = 'op_call_kwargs'
|
42
42
|
INPUTS_AS_LIST = 'inputs_as_list'
|
43
|
-
|
43
|
+
TENSOR_INPUT_ALLOCS = 'tensor_input_allocs'
|
44
44
|
INPLACE = 'inplace'
|
45
45
|
HARDTANH_MIN_VAL = 'min_val'
|
46
46
|
HARDTANH_MAX_VAL = 'max_val'
|
@@ -65,10 +65,10 @@ class ReshapeWithStaticShapes(common.BaseSubstitution):
|
|
65
65
|
|
66
66
|
# When a "reshape" is called with multiple arguments (e.g. x.reshape(-1, channels, height, width)
|
67
67
|
# this substitution converts it x.reshape((-1, channels, height, width)), so need to update the
|
68
|
-
#
|
69
|
-
# scalar argument's shape is [1] so remove those indices from
|
68
|
+
# tensor_input_allocs attribute.
|
69
|
+
# scalar argument's shape is [1] so remove those indices from tensor_input_allocs
|
70
70
|
# node.input_shape example: [[1, 32, 4, 32], [1], [1], [1]]
|
71
|
-
node.
|
71
|
+
node.tensor_input_allocs = node.tensor_input_allocs[:sum([i != [1] for i in node.input_shape])]
|
72
72
|
|
73
73
|
# modify the node input info
|
74
74
|
node.input_shape = [node.input_shape[0]]
|
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.graph.base_graph import OutTensor
|
|
23
23
|
from model_compression_toolkit.core.common.graph.edge import Edge
|
24
24
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
25
25
|
from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
|
26
|
-
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST,
|
26
|
+
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR
|
27
27
|
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
|
28
28
|
from model_compression_toolkit.logger import Logger
|
29
29
|
|
@@ -140,7 +140,7 @@ def nodes_builder(model: GraphModule,
|
|
140
140
|
weights.update({i: consts_dict[input_node]})
|
141
141
|
|
142
142
|
tensor_meta = input_node.meta
|
143
|
-
if tensor_meta[TYPE]
|
143
|
+
if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
|
144
144
|
input_shape += [list(tensor_meta[TENSOR_META].shape)]
|
145
145
|
elif tensor_meta[TYPE] == tuple:
|
146
146
|
input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
|
@@ -159,8 +159,11 @@ def nodes_builder(model: GraphModule,
|
|
159
159
|
|
160
160
|
# filter Nodes from framework attributes, we replace these attributes with nx graph nodes
|
161
161
|
framework_attr_filtered = {}
|
162
|
+
framework_attr_nodes = {}
|
162
163
|
for k, v in framework_attr.items():
|
163
|
-
if
|
164
|
+
if isinstance(v, torch.fx.node.Node):
|
165
|
+
framework_attr_nodes[k] = v
|
166
|
+
else:
|
164
167
|
framework_attr_filtered[k] = v
|
165
168
|
framework_attr = framework_attr_filtered
|
166
169
|
|
@@ -177,7 +180,7 @@ def nodes_builder(model: GraphModule,
|
|
177
180
|
[isinstance(n, torch.fx.node.Node) for n in node.args[0]])
|
178
181
|
inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
|
179
182
|
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
|
180
|
-
|
183
|
+
tensor_input_alloc = []
|
181
184
|
op_call_args = list(node.args)
|
182
185
|
if inputs_as_list:
|
183
186
|
op_call_args.pop(0)
|
@@ -185,7 +188,10 @@ def nodes_builder(model: GraphModule,
|
|
185
188
|
for in_node in node.all_input_nodes:
|
186
189
|
for i, arg in enumerate(node.args):
|
187
190
|
if arg == in_node:
|
188
|
-
|
191
|
+
tensor_input_alloc.append(i)
|
192
|
+
for k, arg in framework_attr_nodes.items():
|
193
|
+
if arg == in_node:
|
194
|
+
tensor_input_alloc.append(k)
|
189
195
|
|
190
196
|
# remove torch.fx.node.Node from inputs to graph_node_type
|
191
197
|
op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
|
@@ -197,7 +203,7 @@ def nodes_builder(model: GraphModule,
|
|
197
203
|
OP_CALL_ARGS: op_call_args,
|
198
204
|
OP_CALL_KWARGS: node_kwargs,
|
199
205
|
INPUTS_AS_LIST: inputs_as_list,
|
200
|
-
|
206
|
+
TENSOR_INPUT_ALLOCS: tensor_input_alloc}
|
201
207
|
else:
|
202
208
|
graph_node_type = BaseNode
|
203
209
|
kwargs = {}
|
{mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.1.0.20240616.434.dist-info → mct_nightly-2.1.0.20240617.451.dist-info}/top_level.txt
RENAMED
File without changes
|