mct-nightly 2.3.0.20250518.615__py3-none-any.whl → 2.3.0.20250520.607__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/METADATA +1 -1
- {mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/RECORD +17 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/fusion/fusing_info.py +99 -32
- model_compression_toolkit/core/common/fusion/graph_fuser.py +5 -3
- model_compression_toolkit/core/common/graph/base_graph.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -2
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +84 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/runner.py +1 -1
- {mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
-mct_nightly-2.3.0.
-model_compression_toolkit/__init__.py,sha256=
+mct_nightly-2.3.0.20250520.607.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+model_compression_toolkit/__init__.py,sha256=RpJZTYpgYNwzucm8C8weG2IEEP-2HAewn-7SA_sMQh0,1557
 model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -9,7 +9,7 @@ model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZ
 model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
 model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
 model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
-model_compression_toolkit/core/runner.py,sha256=
+model_compression_toolkit/core/runner.py,sha256=EM3B_t_TDUr_ttrQvZFhf6qxO9aIAYOwdl5FU8Y32Ow,13064
 model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
 model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
 model_compression_toolkit/core/common/framework_implementation.py,sha256=JQI_eoZZoNk5Y_jAxLfYt9-wzfs7zGpTldz9UblxmMc,21182
@@ -31,10 +31,10 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
 model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
 model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
 model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
-model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=
-model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=
+model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=uDxF0awrjn3SbcpXBpoQ4OGcKO6Z7HBk8ierZPCGbGo,21970
+model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=yxxxuwrmQ4wLW-PlTu0MEW59LmNJEh1OWy9Li15YH-8,7520
 model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
-model_compression_toolkit/core/common/graph/base_graph.py,sha256=
+model_compression_toolkit/core/common/graph/base_graph.py,sha256=YA0c8ucaaZu9eRO-xruLqDT3QFOpxq24ViG6ILS2jqA,41403
 model_compression_toolkit/core/common/graph/base_node.py,sha256=AbUadAT581zelVcGcK9_--6CAGiht9qwkeWahwT3RzE,33389
 model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
 model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
@@ -67,16 +67,16 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_uti
 model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
 model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
 model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=BO4ouM_UVS9Fg0z95gLJSMz1ep6YQC5za_iXI_qW2yQ,5399
-model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256
+model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=axgAypzsiCOw04ZOtOEjK4riuNsaEU2qU6KkWnEXtMo,4951
 model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=KhiHGpmN5QbpyJQnTZmXigdXFlSlRNqpOOyKGj1Fwek,6412
-model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=
+model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=MXOK9WPy3fSt5uxsWYMF4szwwqWWgrlzNJdE9VIb-AQ,28145
 model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4uhUXKgwyMrJqEVK5uJzVr67GI5YzDTHLveV4maB7z0,28079
 model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=Zn6SgzGLWWKmuYGHd1YtKxZdYnQWRDeXEkKlBiTbHcs,2929
 model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=MY8df-c_kITEr_7hOctaxhdiq29hSTA0La9Qo0oTJJY,9678
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
-model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256
-model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=
+model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=8f6KDTKD8SzVXDl9jmYJ-p19cQB0Nr_UTdCPuhELTdg,40329
+model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=ZY5yFIDzbaqIk0UzakDBObfsVevn4fydqAfAm4RCikY,4058
 model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
 model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
@@ -138,14 +138,14 @@ model_compression_toolkit/core/common/statistics_correction/statistics_correctio
 model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
 model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
-model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=
+model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=ndAKcbnNtDQ0DfL9WOYMYPlxU71t7xo9uxvaFZQsfjI,8501
 model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=dWJpVfomF4Ppeeor3VzS23TXHyBm85QI7snyLOYP_ko,9972
 model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
 model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
 model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
 model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
 model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=p57u25qdW2pimxzGwgMXEBV4S-LzXuTVAlIM7830WfU,10966
-model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=
+model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=VcwCVWEooYwg6NGcnSP8OaSzgtzSd4k1r-5a68rpqZc,33713
 model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
 model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=w43dRmaG96a8SNECgghxoFCTSoZ-vUb33dXGm2PbomE,4251
 model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=gt07lXRUvYunJKiwv_w20zfXhcplSW4oT2C1dqiNNXc,4719
@@ -258,7 +258,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=D1hxN3pZ5-_FLJSS30ZJUo-v8TqUWFcMjhMijFa9aSo,12407
-model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=
+model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=n90Fu2ZkuWPoqy1_GchrQSk6O-HlaeuBeVfaCR_O8xI,10755
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
 model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
 model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
 model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
-mct_nightly-2.3.0.
-mct_nightly-2.3.0.
-mct_nightly-2.3.0.
-mct_nightly-2.3.0.
+mct_nightly-2.3.0.20250520.607.dist-info/METADATA,sha256=s41u_n703mnXVAN6OOTRrZKu_w7EGrsBDIv06qil5fo,25135
+mct_nightly-2.3.0.20250520.607.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+mct_nightly-2.3.0.20250520.607.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.3.0.20250520.607.dist-info/RECORD,,
model_compression_toolkit/__init__.py
CHANGED
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

-__version__ = "2.3.0.
+__version__ = "2.3.0.20250520.000607"
model_compression_toolkit/core/common/fusion/fusing_info.py
CHANGED
@@ -36,22 +36,28 @@ class FusingInfo:
     belong to fused operations and validate this info is correct after changes in the graph.

     The core structures maintained are:
+    - 'fusing_patterns': The patterns to generate the fused operators from.
+    - 'manual_fused_ops': List of sequence of node names to handle as fused ops (even if they are not part of the fusing patterns).
     - `fusing_data`: A dictionary mapping fused operation IDs to lists of nodes that belong to that operation.
     - `node_to_fused_node_map`: A dictionary mapping each node name to the ID of the fused operation it belongs to.

     """
-    fusing_patterns: any = None
+    fusing_patterns: List[list[any]] = None
+    manual_fused_ops: List[List[str]] = None
     fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
     node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
     fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)

     def __post_init__(self):
         """Validates and initializes mappings after dataclass instantiation."""
+        self.fusing_patterns = self.fusing_patterns or []
         for op_id, op_nodes in self.fusing_data.items():
             assert isinstance(op_id, str) and op_id.startswith(FUSED_OP_ID_PREFIX), f"Found invalid fused op id: {op_id}"
             assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"

         self._init_node_mapping()
+        self._manual_fused_ops = self.manual_fused_ops or []
+        del self.manual_fused_ops
         self._init_quantization_config_map()

     def _init_node_mapping(self) -> None:
@@ -63,6 +69,26 @@ class FusingInfo:
         for node in nodes:
             self.node_to_fused_node_map[node.name] = op_id

+    def get_manual_nodes_to_fuse(self) -> List[List[str]]:
+        """
+        Get the list of node names to be fused manually.
+        """
+        return self._manual_fused_ops
+
+
+    def add_manual_nodes_to_fuse(self, node_names: List[str]):
+        """
+        Add a list of node names to be fused manually.
+
+        Args:
+            node_names: List of nodes to be fused.
+
+        """
+        assert isinstance(node_names, list)
+        assert all([isinstance(n, str) for n in node_names])
+        assert node_names not in self._manual_fused_ops, f"{node_names} is already in manual fused ops: {self._manual_fused_ops}"
+        self._manual_fused_ops.append(node_names)
+
     def _init_quantization_config_map(self) -> None:
         """
         Init the mapping between fused operation IDs and their quantization configurations.
@@ -121,12 +147,16 @@ class FusingInfo:
             raise ValueError(f"Fused operation {op_id} does not exist.")
         # Remove nodes from the mapping
         nodes = self.fusing_data[op_id]
+        node_names = [n.name for n in nodes]
+        if node_names in self._manual_fused_ops:
+            self._manual_fused_ops.remove(node_names)
+
         for node in nodes:
             self.node_to_fused_node_map.pop(node.name, None)
         del self.fusing_data[op_id]
         self.fused_op_id_to_quant_config.pop(op_id, None)

-    def
+    def get_fused_op_id_for_node(self, node_name: str) -> Optional[str]:
         """
         Get the name of the fused node containing the given original node name.

@@ -168,6 +198,12 @@ class FusingInfo:
         """
         return self.fusing_data.get(op_id)

+    def get_nodes_to_disable_activation_quantization(self) -> List['BaseNode']:
+        """
+        Returns a list of the nodes that their activation quantization is disabled due to fusing.
+        """
+        return [node for nodes in self.get_all_fused_operations().values() for node in nodes[:-1]]
+
     def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig:
         """
         Retrieve the quantization configuration for a given fused operation ID.
@@ -268,7 +304,7 @@ class FusingInfo:

         # Check 4: Ensure the sequence matches a valid fusing pattern
         valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
-        if not is_valid_fusion(valid_fusing_patterns, nodes):
+        if not is_valid_fusion(valid_fusing_patterns, nodes, self._manual_fused_ops):
             raise ValueError(
                 f"Fused operation {op_id} does not match any valid fusing pattern "
                 f"from {valid_fusing_patterns}."
@@ -311,13 +347,17 @@ class FusingInfo:
             f" Total fused operations: {len(self.fusing_data)}\n"
             f" Fusing Data:\n{fusing_data_repr}\n"
             f" Node-to-Fused Mapping:\n {mapping_repr}\n"
+            f" Manual fused ops:\n {self._manual_fused_ops}\n"
             f")"
         )


 class FusingInfoGenerator:
-    def __init__(self, fusing_patterns):
-        self._fusing_patterns = fusing_patterns
+    def __init__(self, fusing_patterns: List[list] = None, manual_fused_ops: List[List[str]] = None):
+        self._fusing_patterns = fusing_patterns or []
+        assert isinstance(self._fusing_patterns, list)
+        self._manual_fused_ops = manual_fused_ops or []
+        assert isinstance(self._manual_fused_ops, list)

     def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
         """
@@ -338,7 +378,7 @@ class FusingInfoGenerator:
         - Each node belongs to at most one fused operation.
         """
         if not self._fusing_patterns:
-            return FusingInfo(fusing_patterns=self._fusing_patterns)
+            return FusingInfo(fusing_patterns=self._fusing_patterns, manual_fused_ops=self._manual_fused_ops)

         # Extract fusing layer patterns
         fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)
@@ -352,31 +392,53 @@ class FusingInfoGenerator:
         fusing_info: Dict[str, Tuple['BaseNode']] = {}
         fused_nodes = []  # nodes that are participating in fusing

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if len(self._fusing_patterns)>0:
+            for node in nodes:
+                # Skip if already in fusing
+                if node in fused_nodes:
+                    continue
+                # Start fusing search
+                fusing_nodes = []  # nodes that are candidates for participating in fusing
+                patterns = copy.deepcopy(fusing_layer_patterns)
+                next_nodes = [node]
+                for i in range(max_layer_patterns):
+                    patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
+                    if len(patterns) == 0:  # Give up if no more fusion pattern
+                        break
+                    fusing_nodes.append(next_nodes[0])
+                    next_nodes = graph.get_next_nodes(fusing_nodes[-1])
+                    if len(next_nodes) != 1:  # Give up if node has more than one connection (not supported for fusion)
+                        break
+
+                # New fusion
+                if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
+                    fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
+                    assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
+                    fusing_info[fused_op_id] = tuple(fusing_nodes)
+                    fused_nodes.extend(fusing_nodes)
+
+        for manual_names in self._manual_fused_ops:
+            manual_nodes = [graph.find_node_by_name(n) for n in manual_names]
+            for n in manual_nodes:
+                if len(n) != 1:
+                    raise ValueError(f"Expected exactly one node, but got {len(n)}")
+            manual_nodes = [n[0] for n in manual_nodes]
+
+            # Remove any existing fused ops containing any of the manual nodes
+            fused_ids_to_remove = {
+                op_id for op_id, nodes in fusing_info.items()
+                if any(node in nodes for node in manual_nodes)
+            }
+            for op_id in fused_ids_to_remove:
+                del fusing_info[op_id]
+
+            fused_op_id = FusingInfo.generate_fused_op_id(manual_nodes)
+            assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
+            fusing_info[fused_op_id] = tuple(manual_nodes)
+
+        return FusingInfo(fusing_data=fusing_info,
+                          fusing_patterns=self._fusing_patterns,
+                          manual_fused_ops=self._manual_fused_ops)


 def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]],
@@ -404,15 +466,20 @@ def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]],
     return valid_fusing_patterns


-def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -> bool:
+def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode'], manual_fused_names: List[List[str]]=None) -> bool:
     """
     Check if the fusion is valid: exist in fusing_patterns
     Args:
         fusing_patterns: supported fusing patterns
         nodes: nodes which are participating in fusion
+        manual_fused_names: list of nodes names to handle as a valid fusing op.
     Returns:
         whether the fusion in valid
     """
+    node_names = [n.name for n in nodes]
+    if any(manual == node_names for manual in (manual_fused_names or [])):
+        return True
+
     fusion_depth = len(nodes)
     if fusion_depth <= 1:
         return False
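Taken together, these hunks add an override channel on top of pattern-based fusing: caller-supplied name sequences (`manual_fused_ops`) evict any overlapping auto-detected group, are registered as fused ops themselves, and are accepted unconditionally by `is_valid_fusion`. Below is a minimal standalone sketch of that override step, using plain strings and dicts in place of MCT's graph and node classes (the `FusedNode_` id prefix is an assumption here, not confirmed by this diff):

```python
# Auto-detected groups, as pattern matching would produce them.
fusing_info = {
    "FusedNode_conv1_relu1": ("conv1", "relu1"),
    "FusedNode_conv2_relu2": ("conv2", "relu2"),
}
manual_fused_ops = [["relu1", "add1"]]  # caller-forced sequences of node names

for manual_names in manual_fused_ops:
    # Evict auto-detected ops that share any node with the manual group.
    for op_id in [op_id for op_id, nodes in fusing_info.items()
                  if any(n in nodes for n in manual_names)]:
        del fusing_info[op_id]
    # Register the manual group as a fused op of its own.
    fusing_info["FusedNode_" + "_".join(manual_names)] = tuple(manual_names)

assert "FusedNode_conv1_relu1" not in fusing_info  # evicted: it contained relu1
assert fusing_info["FusedNode_relu1_add1"] == ("relu1", "add1")
```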
model_compression_toolkit/core/common/fusion/graph_fuser.py
CHANGED
@@ -46,12 +46,14 @@ class GraphFuser:
         The updated graph with fused nodes replacing the original node groups.
         """
         graph_copy = copy.deepcopy(graph)
-        expected_fusing_info = FusingInfoGenerator(graph_copy.fusing_info.fusing_patterns
+        expected_fusing_info = FusingInfoGenerator(graph_copy.fusing_info.fusing_patterns,
+                                                   graph_copy.fusing_info.get_manual_nodes_to_fuse()).generate_fusing_info(graph_copy)

-
+        existing_fusing_info = graph_copy.fusing_info
+        if expected_fusing_info != existing_fusing_info:
             raise ValueError(
                 f"Mismatch between expected and existing fusing information.\n"
-                f"Expected:\n{expected_fusing_info}\nExisting:\n{
+                f"Expected:\n{expected_fusing_info}\nExisting:\n{existing_fusing_info}"
             )

         fused_operations = list(graph_copy.fusing_info.get_all_fused_operations().items())
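GraphFuser now rebuilds the fusing information from scratch (including the manual groups) and compares it against what the graph already carries before collapsing anything, failing fast on drift instead of fusing stale groups. A toy sketch of that validate-before-mutate pattern, with hypothetical names standing in for MCT's API:

```python
import copy

def derive_fusing_info(graph: dict) -> dict:
    # Hypothetical recomputation from topology; a toy rule stands in for pattern matching.
    return {"FusedNode_a_b": ("a", "b")} if ("a", "b") in graph["edges"] else {}

def fuse(graph: dict) -> dict:
    graph_copy = copy.deepcopy(graph)
    expected = derive_fusing_info(graph_copy)   # recompute from scratch
    existing = graph_copy["fusing_info"]        # what the graph already carries
    if expected != existing:
        raise ValueError(f"Mismatch between expected and existing fusing information.\n"
                         f"Expected:\n{expected}\nExisting:\n{existing}")
    # ...only now is it safe to collapse each fused group into a single node...
    return graph_copy

fuse({"edges": [("a", "b")], "fusing_info": {"FusedNode_a_b": ("a", "b")}})  # passes
```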
model_compression_toolkit/core/common/graph/base_graph.py
CHANGED
@@ -908,7 +908,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         Disable activation quantization for all nodes in fused operations,
         except for the last node in each fused group.
         """
-        nodes_to_disable =
+        nodes_to_disable = self.fusing_info.get_nodes_to_disable_activation_quantization()
         for node in nodes_to_disable:
             for qc in node.candidates_quantization_cfg:
                 qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
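The helper that base_graph.py now delegates to keeps every member of a fused group except the last one, since only the group's final output needs a real activation quantizer. A tiny illustration with strings standing in for `BaseNode` objects:

```python
# op id -> ordered tuple of member node names
fused_ops = {
    "FusedNode_conv_relu": ("conv", "relu"),
    "FusedNode_fc_add_tanh": ("fc", "add", "tanh"),
}
# Mirrors get_nodes_to_disable_activation_quantization(): all but each group's last node.
inner_nodes = [n for nodes in fused_ops.values() for n in nodes[:-1]]
assert inner_nodes == ["conv", "fc", "add"]  # relu and tanh keep their activation quantizers
```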
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py
CHANGED
@@ -51,7 +51,7 @@ class MixedPrecisionRUHelper:
         """
         act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg)

-        ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.
+        ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,
                                                                          BitwidthMode.QCustom,
                                                                          act_qcs=act_qcs,
                                                                          w_qcs=w_qcs,
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
CHANGED
@@ -294,8 +294,12 @@ class MixedPrecisionSearchManager:
         """
         act_qcs, w_qcs = self.orig_graph_ru_helper.get_quantization_candidates(config)
         ru = self.orig_graph_ru_helper.ru_calculator.compute_resource_utilization(
-            target_criterion=TargetInclusionCriterion.
-
+            target_criterion=TargetInclusionCriterion.AnyQuantizedNonFused,
+            bitwidth_mode=BitwidthMode.QCustom,
+            act_qcs=act_qcs,
+            w_qcs=w_qcs,
+            ru_targets=self.ru_targets,
+            allow_unused_qcs=True)
         return ru

     def _finalize_distance_metric(self, layer_to_metrics_mapping: Dict[BaseNode, List[float]]):
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
CHANGED
@@ -67,11 +67,13 @@ class TargetInclusionCriterion(Enum):
     QNonConfigurable: non-configurable targets (single quantization candidate).
     AnyQuantized: any quantized targets (configurable and non-configurable).
     Any: all targets (quantized + float).
+    QuantizedNonFused: any quantized targets that are not inside fused operations.
     """
     QConfigurable = auto()
     QNonConfigurable = auto()
     AnyQuantized = auto()
     Any = auto()
+    AnyQuantizedNonFused = auto()


 class Utilization(NamedTuple):
@@ -534,8 +536,9 @@ class ResourceUtilizationCalculator:
         assert not isinstance(n, VirtualNode), 'Use original graph to compute BOPS.'
         if target_criterion is None:
             target_criterion = TargetInclusionCriterion.Any
-        if target_criterion not in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.Any]:
-            raise ValueError(
+        if target_criterion not in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.AnyQuantizedNonFused, TargetInclusionCriterion.Any]:
+            raise ValueError(
+                'BOPS computation is supported only for Any, AnyQuantized and AnyQuantizedNonFused targets.')

         self._validate_custom_qcs(act_qcs, bitwidth_mode)
         self._validate_custom_qcs(w_qc, bitwidth_mode)
@@ -621,7 +624,7 @@ class ResourceUtilizationCalculator:
         weight_attrs = n.get_node_weights_attributes()
         if target_criterion == TargetInclusionCriterion.QConfigurable:
             weight_attrs = [attr for attr in weight_attrs if n.is_configurable_weight(attr)]
-        elif target_criterion
+        elif target_criterion in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.AnyQuantizedNonFused]:
             weight_attrs = [attr for attr in weight_attrs if n.is_weights_quantization_enabled(attr)]
         elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
             quantized = [attr for attr in weight_attrs if n.is_weights_quantization_enabled(attr)]
@@ -671,6 +674,10 @@ class ResourceUtilizationCalculator:
             nodes = [n for n in nodes if n.has_configurable_activation()]
         elif target_criterion == TargetInclusionCriterion.AnyQuantized:
             nodes = [n for n in nodes if n.is_activation_quantization_enabled() or n.is_quantization_preserving()]
+        elif target_criterion == TargetInclusionCriterion.AnyQuantizedNonFused:
+            nodes = [n for n in nodes if n.is_activation_quantization_enabled() or n.is_quantization_preserving()]
+            # remove fused nodes (due to SNC, where the non-linear is quantized, even though it should not be quantized)
+            nodes = [n for n in nodes if n not in self.graph.fusing_info.get_nodes_to_disable_activation_quantization()]
         elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
             nodes = [n for n in nodes if n.is_activation_quantization_enabled() and not n.has_configurable_activation()]
         elif target_criterion != TargetInclusionCriterion.Any:  # pragma: no cover
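The new `AnyQuantizedNonFused` criterion starts from the same node set as `AnyQuantized` and then subtracts the nodes whose activation quantization is disabled by fusing (relevant after shift negative correction, where the non-linear node still carries a quantizer it should not contribute to utilization). A standalone sketch of that two-step filter, with simple records instead of MCT nodes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    name: str
    act_quant_enabled: bool

nodes = [Node("conv", True), Node("relu", True), Node("fc", True)]
# Suppose conv is an inner member of a fused op (conv -> relu): fusing disables its activation.
fused_inner = {Node("conv", True)}

any_quantized = [n for n in nodes if n.act_quant_enabled]
any_quantized_non_fused = [n for n in any_quantized if n not in fused_inner]
assert [n.name for n in any_quantized_non_fused] == ["relu", "fc"]
```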
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
CHANGED
@@ -63,4 +63,4 @@ def compute_resource_utilization_data(in_model: Any,
                                       running_gptq=False)

     ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
-    return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.
+    return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
CHANGED
@@ -149,7 +149,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
         # the current info, or this creates a new fusion and the old pattern should be
         # replaced with the new one.
         fi = graph.fusing_info
-        fused_op = fi.
+        fused_op = fi.get_fused_op_id_for_node(source_node.name)
         if fused_op:
             fused_nodes = list(fi.get_fused_nodes(fused_op))
             assert source_node in fused_nodes
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
CHANGED
@@ -23,8 +23,6 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
 from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
-from mct_quantizers import QuantizationMethod
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
     set_quantization_configs_to_node
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -33,6 +31,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
     _mse_error_histogram
 from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
+from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig

 """
 This substitution aims to solve an issue of activation with negative outputs where
@@ -188,6 +187,65 @@ def remove_node_between_two_nodes(graph: Graph,
     graph.remove_node(node_to_remove)


+def fuse_padding_with_op2d(graph: 'BaseGraph', pad_node: 'BaseNode', op2d_node: 'BaseNode') -> None:
+    """
+    Add a padding node to the fused operation containing op2d_node.
+    If op2d_node is not already in a fused op, create a new fused group with both nodes.
+
+    Args:
+        graph: The computational graph.
+        pad_node: The padding node to be added.
+        op2d_node: The Conv2D or similar op node following the pad.
+    """
+    fusing_info = graph.fusing_info
+
+    if fusing_info.is_node_in_fused_op(op2d_node):
+        fused_id = fusing_info.get_fused_op_id_for_node(op2d_node.name)
+        fused_nodes = fusing_info.get_fused_nodes(fused_id)
+        fusing_info.remove_fused_operation(fused_id)
+    else:
+        fused_nodes = [op2d_node]
+
+    new_fused_nodes = [pad_node] + list(fused_nodes)
+    fused_op_id = fusing_info.generate_fused_op_id(new_fused_nodes)
+
+    fusing_info.add_fused_operation(fused_op_id, tuple(new_fused_nodes))
+    fusing_info.add_manual_nodes_to_fuse([n.name for n in new_fused_nodes])
+
+def update_fused_op_with_add(graph: 'BaseGraph', non_linear_node: 'BaseNode', add_node: 'BaseNode') -> None:
+    """
+    Update the fused operation to include an Add node that follows a non-linear activation node.
+
+    Args:
+        graph: The computational graph.
+        non_linear_node: The non-linear activation node (e.g., ReLU).
+        add_node: The Add node inserted after the non-linear node.
+    """
+    fusing_info = graph.fusing_info
+    prev_node = graph.get_prev_nodes(non_linear_node)[0]
+
+    # Gather existing fused nodes (if any)
+    fused_candidates = []
+    for node in (prev_node, non_linear_node):
+        if fusing_info.is_node_in_fused_op(node):
+            fused_id = fusing_info.get_fused_op_id_for_node(node.name)
+            fused_candidates.extend(fusing_info.get_fused_nodes(fused_id))
+    fused_candidates.append(add_node)
+
+    # Remove duplicates while preserving order
+    fused_candidates = list(dict.fromkeys(fused_candidates))
+
+    # Remove existing fused ops involving prev_node or non_linear_node
+    for node in (prev_node, non_linear_node):
+        if fusing_info.is_node_in_fused_op(node):
+            fusing_info.remove_fused_operation(fusing_info.get_fused_op_id_for_node(node.name))
+
+    # Register new fused operation
+    fused_op_id = fusing_info.generate_fused_op_id(fused_candidates)
+    fusing_info.add_manual_nodes_to_fuse([n.name for n in fused_candidates])
+    fusing_info.add_fused_operation(fused_op_id, tuple(fused_candidates))
+
+
 def shift_negative_function(graph: Graph,
                             core_config: CoreConfig,
                             non_linear_node: BaseNode,
@@ -232,7 +290,6 @@ def shift_negative_function(graph: Graph,
     Returns:
         Graph after applying the shifting and correction.
     """
-
     min_to_correct, max_value2compare = graph.get_out_stats_collector(non_linear_node).get_min_max_values()

     if not non_linear_node.is_all_activation_candidates_equal():
@@ -242,6 +299,7 @@ def shift_negative_function(graph: Graph,
     # all candidates have same activation config, so taking the first candidate for calculations
     non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg

+
     # get the non-linear activation threshold
     activation_threshold = non_linear_node_cfg_candidate.activation_quantization_params.get(THRESHOLD)

@@ -350,7 +408,14 @@ def shift_negative_function(graph: Graph,
                                         fqc=graph.fqc,
                                         mixed_precision_enable=core_config.is_mixed_precision_enabled)

-
+    update_fused_op_with_add(graph=graph,
+                             non_linear_node=non_linear_node,
+                             add_node=add_node)
+
+    # If sum([pad_top, pad_btm, pad_left, pad_right])==0 it means we do not pad in any side, thus
+    # we do not add a padding node as this is meaningless
+    pad_node = None
+    if padding is not None and sum([pad_top, pad_btm, pad_left, pad_right])>0:
         pad_node = create_pad_node(op2d_node.name,
                                    add_node.name,
                                    shift_value,
@@ -394,8 +459,16 @@ def shift_negative_function(graph: Graph,
     graph.shift_stats_collector(bypass_node, np.array(shift_value))

     add_node_qco = add_node.get_qco(graph.fqc).quantization_configurations
+    add_supported_bitwidths = [c.activation_n_bits for c in add_node_qco]
+    if original_non_linear_activation_nbits not in add_supported_bitwidths:
+        raise ValueError(
+            f"Add supported activation bit-widths according to the TPC are: {add_supported_bitwidths}, but non-linear "
+            f"bitwidth is {original_non_linear_activation_nbits}. Consider adapting the TPC so 'Add' will support the "
+            f"same bitwidth as {non_linear_node.type} or disable shift negative correction.")
+
     for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
         for attr in add_node.get_node_weights_attributes():
+            # TODO: do we not quantize the weights of this 'add' on purpose?
             candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False

             candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config,
@@ -404,8 +477,15 @@ def shift_negative_function(graph: Graph,

         candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
                                                                                     SIGNED: False})
+
         candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits

+    # Add the new padding node to a fused op with the op2d.
+    if pad_node:
+        fuse_padding_with_op2d(graph=graph,
+                               pad_node=pad_node,
+                               op2d_node=op2d_node)
+
     if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
         activation_param = get_activations_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
                                                    nodes_prior_info=non_linear_node.prior_info,
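`update_fused_op_with_add` merges the members of up to two existing fused ops with the new Add node, so the same node can be collected twice; `dict.fromkeys` then gives an order-preserving dedup, where a plain `set` would scramble the fusing order:

```python
fused_candidates = ["conv", "relu", "relu", "add"]  # relu gathered from two fused ops
assert list(dict.fromkeys(fused_candidates)) == ["conv", "relu", "add"]
```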
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import copy
 import operator
 from typing import Tuple, Any, Callable

@@ -149,9 +150,9 @@ def create_pad_node(next_node_name: str,
     op_call_kwargs = {PAD: [pad_left, pad_right, pad_top, pad_btm],
                       VALUE: float(value_to_pad)}

-    padded_shape = input_shape[0]
-    padded_shape[
-    padded_shape[
+    padded_shape = copy.deepcopy(input_shape[0])
+    padded_shape[2] += pad_top + pad_btm
+    padded_shape[3] += pad_left + pad_right
     pad_node = common.graph.functional_node.FunctionalNode(pad_node_name,
                                                            {},
                                                            input_shape,
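The switch to `copy.deepcopy` in `create_pad_node` most likely fixes an aliasing bug: the old code mutated `input_shape[0]` in place, so the pad sizes leaked into the shape list the node was built from. A minimal illustration of the difference:

```python
import copy

input_shape = [[1, 3, 32, 32]]                # NCHW shape as a nested list
padded_shape = input_shape[0]                 # old behavior: alias, not a copy
padded_shape[2] += 1 + 1                      # pad_top + pad_btm
assert input_shape[0][2] == 34                # original shape silently mutated

input_shape = [[1, 3, 32, 32]]
padded_shape = copy.deepcopy(input_shape[0])  # new behavior: independent copy
padded_shape[2] += 1 + 1
assert input_shape[0][2] == 32                # original stays intact
```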
@@ -241,5 +242,4 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
                                             PADDING,
                                             BIAS,
                                             USE_BIAS,
-                                            params_search_quantization_fn=params_search_quantization_fn
-                                            )
+                                            params_search_quantization_fn=params_search_quantization_fn)
model_compression_toolkit/core/runner.py
CHANGED
@@ -220,7 +220,7 @@ def _set_final_resource_utilization(graph: Graph,
     ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
     w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
     a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
-    final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.
+    final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,
                                                           BitwidthMode.QCustom, act_qcs=a_qcs, w_qcs=w_qcs,
                                                           ru_targets=ru_targets, allow_unused_qcs=True)
     summary = final_ru.get_summary_str(restricted=True)
{mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/WHEEL
RENAMED
File without changes
{mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/licenses/LICENSE.md
RENAMED
File without changes
{mct_nightly-2.3.0.20250518.615.dist-info → mct_nightly-2.3.0.20250520.607.dist-info}/top_level.txt
RENAMED
File without changes