mct-nightly 2.3.0.20250402.536__py3-none-any.whl → 2.3.0.20250404.535__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/METADATA +1 -1
- {mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/RECORD +18 -18
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/fusion/fusing_info.py +374 -0
- model_compression_toolkit/core/common/fusion/graph_fuser.py +50 -28
- model_compression_toolkit/core/common/graph/base_graph.py +89 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +8 -0
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +8 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +16 -1
- model_compression_toolkit/core/graph_prep_runner.py +5 -2
- model_compression_toolkit/core/runner.py +3 -4
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v2.py +2 -66
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py +0 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +0 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +0 -131
- {mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.
|
3
|
+
Version: 2.3.0.20250404.535
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: Apache Software License
|
{mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250404.535.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=Xy_GrGTjrv9Us1_tnSwgsiJDh-wjxsYto2Xpa5zo45M,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
@@ -7,9 +7,9 @@ model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_A
|
|
7
7
|
model_compression_toolkit/verify_packages.py,sha256=TlS-K1EP-QsghqWUW7SDPkAJiUf7ryw4tvhFDe6rCUk,1405
|
8
8
|
model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZG1FUxxJlAV8L8,2096
|
9
9
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
10
|
-
model_compression_toolkit/core/graph_prep_runner.py,sha256=
|
10
|
+
model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
|
11
11
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
|
12
|
-
model_compression_toolkit/core/runner.py,sha256=
|
12
|
+
model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utybvIVbH06CBHY,13056
|
13
13
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
14
14
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
15
15
|
model_compression_toolkit/core/common/framework_implementation.py,sha256=s3yiqnbWkwfnAB1sSal_KAuqVg27rLhAJ2O8LHUbSHE,22494
|
@@ -31,10 +31,10 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
|
|
31
31
|
model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
|
32
32
|
model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
|
33
33
|
model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
34
|
-
model_compression_toolkit/core/common/fusion/
|
35
|
-
model_compression_toolkit/core/common/fusion/
|
34
|
+
model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=LfzVS9B6r2KCwf8rcCUdepEQhWkt287SoXfwoudpfFo,15496
|
35
|
+
model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
|
36
36
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
37
|
-
model_compression_toolkit/core/common/graph/base_graph.py,sha256=
|
37
|
+
model_compression_toolkit/core/common/graph/base_graph.py,sha256=hedhjVula5rPv0vN0CLBDtPYM8SH3cM6FAL62aFfF7U,41767
|
38
38
|
model_compression_toolkit/core/common/graph/base_node.py,sha256=CJu8_r80MGVnYmlAUGOGKGRsD9xShMyaRNb3VMeRC0s,34523
|
39
39
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
40
40
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
|
@@ -68,7 +68,7 @@ model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha2
|
|
68
68
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
|
70
70
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=2Pp4hiYvGW2I9YhloDxQNT0sZRg3TDp9CXObloF8IFU,4971
|
71
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=
|
71
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=GGrp7QngrWvWtPN8cQnL4IEbNwcVRc-hAUqfnxjjMmk,5998
|
72
72
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=NBzzhkVI407S9cIiw7t7nsP3MrkOdSnweKQdPBXb8to,38180
|
73
73
|
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=gsigifJ-ykWNafF4t7UMEC_-nd6YPERAk1_z0kT-Y88,27172
|
74
74
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
@@ -102,7 +102,7 @@ model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=77
|
|
102
102
|
model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=_LcDAxLeC5I0KdMHS8jib5XxIKO2ZLavXYuSMIPIQBo,5868
|
103
103
|
model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
104
104
|
model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=0HA3CIZW-ZrA55ra-yJXRvAYnoR8i1SjpbnMDKcWYNQ,12819
|
105
|
-
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=
|
105
|
+
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=lyWPvnoX8BmulhLKR20r5gT2_Yan7P40d8EcgDhErPk,4905
|
106
106
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
107
107
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
|
108
108
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=IHVX-Gdekru4xLuDTgcsp_JCnRtuVWnbYsDBQuSXTKc,7079
|
@@ -138,7 +138,7 @@ model_compression_toolkit/core/common/statistics_correction/statistics_correctio
|
|
138
138
|
model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
139
139
|
model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
|
140
140
|
model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
|
141
|
-
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=
|
141
|
+
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=1389z4NbTKIHYGr-FB-fV1YP1Gcfta0tOu60DwfNVlI,8452
|
142
142
|
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=dWJpVfomF4Ppeeor3VzS23TXHyBm85QI7snyLOYP_ko,9972
|
143
143
|
model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
|
144
144
|
model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
|
@@ -435,14 +435,14 @@ model_compression_toolkit/target_platform_capabilities/constants.py,sha256=BFSgD
|
|
435
435
|
model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
|
436
436
|
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=4ydTWWKv_PEOAFok2JtxFNj8rav-0IlqcXKF6lnhHNE,4157
|
437
437
|
model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
|
438
|
-
model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=
|
438
|
+
model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
|
439
439
|
model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
|
440
440
|
model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=4CGpWENuOyjwaIMaGrFI0Act7jsSeT7m94pjrv91dxE,27516
|
441
|
-
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=
|
441
|
+
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=yg0ZrsaqaS69lmDvxRrz636CRARzx_eZbokTMVHNEXc,4555
|
442
442
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
|
443
443
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
|
444
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=
|
445
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256
|
444
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=mxc3DBbUi-HDFgSx8Nmnyxr8SIdbx8lmtcRMsQl1BLE,7578
|
445
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=8spnpqxVUv8WF9-PTukOLvJAFiNi01wNowUVIDqSj5I,6321
|
446
446
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
|
447
447
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
|
448
448
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
|
@@ -527,7 +527,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
527
527
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
528
528
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
529
529
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
530
|
-
mct_nightly-2.3.0.
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
530
|
+
mct_nightly-2.3.0.20250404.535.dist-info/METADATA,sha256=cb-U_2NM6U6KUmtNnw8cDsM_XjdMPgJrdJkZxDQEn9I,27098
|
531
|
+
mct_nightly-2.3.0.20250404.535.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
532
|
+
mct_nightly-2.3.0.20250404.535.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
533
|
+
mct_nightly-2.3.0.20250404.535.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250404.000535"
|
@@ -0,0 +1,374 @@
|
|
1
|
+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
|
17
|
+
from dataclasses import dataclass, field
|
18
|
+
|
19
|
+
from typing import Optional, List, Dict, Any, Tuple
|
20
|
+
import copy
|
21
|
+
|
22
|
+
# The prefix of each fused operator ID (the suffix is the underscore-joined
|
23
|
+
# names of the nodes that make up the fused operator).
|
24
|
+
FUSED_OP_ID_PREFIX = "FusedNode_"
|
25
|
+
|
26
|
+
|
27
|
+
@dataclass
class FusingInfo:
    """
    This class manages information about fused operations in a graph.

    The key responsibility of this class is maintaining a mapping between original nodes
    and their corresponding fused operation IDs. This mapping helps track which nodes
    belong to fused operations and validate this info is correct after changes in the graph.

    The core structures maintained are:
    - `fusing_patterns`: The supported fusing patterns the fused operations are checked against.
    - `fusing_data`: A dictionary mapping fused operation IDs to tuples of nodes that belong to that operation.
    - `node_to_fused_node_map`: A dictionary mapping each node name to the ID of the fused operation it belongs to.
      It is derived from `fusing_data` and is therefore not an init argument.
    """
    fusing_patterns: Any = None
    fusing_data: Dict[str, Tuple['BaseNode', ...]] = field(default_factory=dict)
    node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)

    def __post_init__(self):
        """Validates and initializes mappings after dataclass instantiation."""
        for op_id, op_nodes in self.fusing_data.items():
            assert isinstance(op_id, str) and op_id.startswith(FUSED_OP_ID_PREFIX), f"Found invalid fused op id: {op_id}"
            assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"

        self._init_node_mapping()

    def _init_node_mapping(self) -> None:
        """
        Init the node-to-fused-node mapping based on the initial fusing data.
        """
        self.node_to_fused_node_map.clear()
        for op_id, nodes in self.fusing_data.items():
            for node in nodes:
                self.node_to_fused_node_map[node.name] = op_id

    def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode', ...]) -> None:
        """
        Add a new fused operation with the given ID and set of nodes.

        Args:
            op_id (str): The identifier for the fused operation.
            nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.

        Raises:
            ValueError: If the operation ID already exists.
        """
        if op_id in self.fusing_data:
            raise ValueError(f"Fused operation {op_id} already exists.")
        assert isinstance(nodes, tuple), f"Expected nodes to be a tuple but its type is {type(nodes)}"
        self.fusing_data[op_id] = nodes
        # Update the mapping for these nodes
        for node in nodes:
            self.node_to_fused_node_map[node.name] = op_id

    def remove_fused_operation(self, op_id: str) -> None:
        """
        Remove a fused operation by its ID.

        Args:
            op_id (str): The identifier for the fused operation to remove.

        Raises:
            ValueError: If the operation ID does not exist.
        """
        if op_id not in self.fusing_data:
            raise ValueError(f"Fused operation {op_id} does not exist.")
        # Remove nodes from the mapping
        nodes = self.fusing_data[op_id]
        for node in nodes:
            self.node_to_fused_node_map.pop(node.name, None)
        del self.fusing_data[op_id]

    def get_fused_node_name(self, node_name: str) -> Optional[str]:
        """
        Get the name of the fused node containing the given original node name.

        Args:
            node_name: The name of a node from the original graph.

        Returns:
            The name of the fused node containing this node, or None if not fused.
        """
        return self.node_to_fused_node_map.get(node_name)

    def get_node_to_fused_node_map(self) -> Dict[str, str]:
        """
        Retrieve a copy of the mapping from original node names to fused node names.

        Returns:
            A dictionary mapping each original node name to its fused node name.
        """
        return self.node_to_fused_node_map.copy()

    def get_fused_nodes(self, op_id: str) -> Optional[Tuple['BaseNode', ...]]:
        """
        Retrieve the nodes for a given fused operation ID.

        Args:
            op_id (str): The identifier for the fused operation.

        Returns:
            Optional[Tuple[BaseNode]]: The stored tuple of nodes for the operation, or None if not found.
        """
        return self.fusing_data.get(op_id)

    def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
        """
        Check if a node is part of any fused operation.

        Args:
            node (BaseNode): The node to check.

        Returns:
            bool: True if the node is in any fused operation, False otherwise.
        """
        return any(node in nodes for nodes in self.fusing_data.values())

    def get_all_fused_operations(self) -> Dict[str, Tuple['BaseNode', ...]]:
        """
        Retrieve fused information.

        Returns:
            Dict[str, Tuple[BaseNode]]: The fusing data. This is the internal dictionary
            itself (not a copy), so mutating it changes this object.
        """
        return self.fusing_data

    @staticmethod
    def generate_fused_op_id(nodes: List['BaseNode']) -> str:
        """
        Generates an identifier for a fused operation by concatenating
        the names of the given nodes with a prefix.

        Args:
            nodes (List[BaseNode]): A list of nodes to be fused.

        Returns:
            str: An identifier string for the fused operation.
        """
        return FUSED_OP_ID_PREFIX + '_'.join([node.name for node in nodes])

    def validate(self, graph) -> None:
        """
        Validate that the fusing information is consistent with the given graph and generation logic.

        This method performs the following checks:
        1. All nodes in the fusing data exist in the graph.
        2. Each fused sequence forms a valid linear chain in the graph:
           - Each node (except the last) has exactly one successor, which is the next node in the sequence.
        3. No node is part of more than one fused operation.
        4. Each fused sequence matches a valid fusing pattern from the original set.

        Args:
            graph: The computational graph to validate against. It is expected to have:
                - `get_topo_sorted_nodes()`: Returns a list of nodes in topological order.
                - `get_next_nodes(node)`: Returns a list of direct successor nodes.

        Raises:
            ValueError: If any validation check fails.
        """
        graph_nodes = set(graph.get_topo_sorted_nodes())  # Retrieve all nodes from the graph
        all_fused_nodes = set()  # Track all nodes used in fusions to ensure no overlap

        for op_id, nodes in self.fusing_data.items():
            # Check 1: Ensure all fused nodes exist in the graph
            for node in nodes:
                if node not in graph_nodes:
                    raise ValueError(f"Fused operation {op_id} contains node {node.name} not present in the graph.")

            # Check 2: Validate the fusion sequence forms a valid linear chain
            for i in range(len(nodes) - 1):  # Up to the second-to-last node
                current_node = nodes[i]
                next_node = nodes[i + 1]
                successors = graph.get_next_nodes(current_node)
                if len(successors) != 1 or successors[0] != next_node:
                    raise ValueError(
                        f"Fused operation {op_id} is not a valid linear chain: "
                        f"node {current_node.name} does not connect directly to {next_node.name} "
                        f"with exactly one successor (found successors: {[n.name for n in successors]})."
                    )

            # Check 3: Ensure no node is reused across fusions
            node_set = set(nodes)
            overlap = node_set & all_fused_nodes
            if overlap:
                raise ValueError(
                    f"Fused operation {op_id} contains nodes already used in another fusion: "
                    f"{[node.name for node in overlap]}."
                )
            all_fused_nodes.update(node_set)

            # Check 4: Ensure the sequence matches a valid fusing pattern.
            # Missing patterns (None / empty) cannot match anything; checking this first
            # makes the failure the documented ValueError instead of a TypeError raised
            # while iterating over None.
            if not self.fusing_patterns or not is_valid_fusion(self.fusing_patterns, nodes):
                raise ValueError(
                    f"Fused operation {op_id} does not match any valid fusing pattern "
                    f"from {self.fusing_patterns}."
                )

    def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
        """
        Check whether the given nodes are eligible to be fused based on predefined fusing patterns.

        This method uses the patterns stored in `self.fusing_patterns` and verifies whether the
        given sequence of nodes matches any of the valid patterns.

        Args:
            nodes (List[BaseNode]): The list of nodes to check for fusion eligibility.

        Returns:
            bool: True if the nodes can be fused according to fusing patterns, otherwise False.
        """
        # If no fusing patterns are defined, fusion is not possible
        if not self.fusing_patterns:
            return False

        # Check if the provided nodes match a valid fusion pattern
        return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes)

    def __repr__(self) -> str:
        """
        Return a string representation of the fusing information.
        """
        fusing_data_repr = "\n".join(
            f"  {op_id}: [{', '.join(node.name for node in nodes)}]"
            for op_id, nodes in self.fusing_data.items()
        )
        mapping_repr = ", ".join(
            f"{node} -> {op_id}" for node, op_id in self.node_to_fused_node_map.items()
        )
        return (
            f"FusingInfo(\n"
            f"  Total fused operations: {len(self.fusing_data)}\n"
            f"  Fusing Data:\n{fusing_data_repr}\n"
            f"  Node-to-Fused Mapping:\n  {mapping_repr}\n"
            f")"
        )
|
264
|
+
|
265
|
+
|
266
|
+
class FusingInfoGenerator:
    """
    Builds a FusingInfo object for a graph by scanning it for node sequences
    that match one of the supported fusing patterns.
    """

    def __init__(self, fusing_patterns):
        """
        Args:
            fusing_patterns: The supported fusing patterns; each pattern is a sequence of
                layer descriptors (see get_valid_fusing_patterns_for_node for how they are matched).
        """
        self._fusing_patterns = fusing_patterns

    def generate_fusing_info(self, graph) -> FusingInfo:
        """
        Generate fusing information based on the graph and fusing patterns.

        Args:
            graph: The input graph to analyze, expected to have methods like
                   get_topo_sorted_nodes() and get_next_nodes(node).

        Returns:
            A FusingInfo object holding the fusing patterns and a mapping from each
            generated fused operation ID (see FusingInfo.generate_fused_op_id) to the
            tuple of nodes participating in that fusion. If no fusing patterns are
            defined, the returned FusingInfo contains no fused operations.

        Notes:
            - Assumes get_valid_fusing_patterns_for_node and is_valid_fusion functions are defined elsewhere.
            - Nodes are processed in topological order to respect operation sequence.
            - Fusions are linear sequences (each node has exactly one successor).
            - Each node belongs to at most one fused operation.
        """
        if not self._fusing_patterns:
            return FusingInfo(fusing_patterns=self._fusing_patterns)

        # Find max fusion (patterns are guaranteed non-empty past the guard above)
        max_layers_fusing = max(len(fusing_pattern) for fusing_pattern in self._fusing_patterns)

        # Travel along the graph to find layers for fusing
        nodes = graph.get_topo_sorted_nodes()

        fusing_info: Dict[str, Tuple['BaseNode']] = {}
        # Nodes that are participating in fusing. Only used for membership tests, so a set
        # avoids a linear scan per visited node (graph nodes are hashable; FusingInfo.validate
        # already places them in sets).
        fused_nodes = set()

        for node in nodes:
            # Skip if already in fusing
            if node in fused_nodes:
                continue
            # Start fusing search
            fusing_nodes = []  # nodes that are candidates for participating in fusing
            # NOTE(review): the deep copy looks unnecessary since the filtering helper builds a
            # new list rather than mutating its input - confirm before removing.
            patterns = copy.deepcopy(self._fusing_patterns)
            next_nodes = [node]
            for i in range(max_layers_fusing):
                patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
                if len(patterns) == 0:  # Give up if no more fusion pattern
                    break
                fusing_nodes.append(next_nodes[0])
                next_nodes = graph.get_next_nodes(fusing_nodes[-1])
                if len(next_nodes) != 1:  # Give up if node has more than one connection (not supported for fusion)
                    break

            # New fusion
            if is_valid_fusion(self._fusing_patterns, fusing_nodes):
                fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
                assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
                fusing_info[fused_op_id] = tuple(fusing_nodes)
                fused_nodes.update(fusing_nodes)

        return FusingInfo(fusing_data=fusing_info, fusing_patterns=self._fusing_patterns)
|
325
|
+
|
326
|
+
|
327
|
+
def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]],
                                       node: 'BaseNode',
                                       idx: int = 0) -> List[List[Any]]:
    """
    Returns only the fusing patterns whose layer at index idx matches the given node,
    either by filter params (when the layer is a LayerFilterParams) or by type.

    The input list is not modified; a new list referencing the matching patterns is returned.

    Args:
        fusing_patterns: supported fusings
        node: node to decide if it can be a part of fusion
        idx: index of layer in the fusion

    Returns:
        fusing_patterns after filtering non-relevant fusions
    """
    valid_fusing_patterns = []
    for fusing_pattern in fusing_patterns:
        # Patterns too short to have a layer at this position cannot match
        if idx >= len(fusing_pattern):
            continue
        layer = fusing_pattern[idx]
        if (type(layer) == LayerFilterParams and node.is_match_filter_params(layer)) or node.is_match_type(layer):
            valid_fusing_patterns.append(fusing_pattern)

    # Return only valid patterns for this node
    return valid_fusing_patterns
|
350
|
+
|
351
|
+
|
352
|
+
def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -> bool:
    """
    Check if the fusion is valid: the node sequence matches, layer by layer, one of
    the supported fusing patterns of the same length.

    Args:
        fusing_patterns: supported fusing patterns (None or empty means nothing can be fused)
        nodes: nodes which are participating in fusion

    Returns:
        whether the fusion is valid
    """
    fusion_depth = len(nodes)
    # A fusion needs at least two nodes, and there must be at least one pattern to match
    # against (guarding here avoids a TypeError when fusing_patterns is None).
    if fusion_depth <= 1 or not fusing_patterns:
        return False
    for fusing_pattern in fusing_patterns:
        if fusion_depth != len(fusing_pattern):
            continue
        # Every node must match the layer at the same position in the pattern
        if all((type(layer) == LayerFilterParams and node.is_match_filter_params(layer))
               or node.is_match_type(layer)
               for node, layer in zip(nodes, fusing_pattern)):
            return True
    return False
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,10 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
16
|
+
import copy
|
17
|
+
from typing import List, Tuple
|
17
18
|
|
18
|
-
from model_compression_toolkit.core.common import
|
19
|
-
from model_compression_toolkit.core.common.graph.base_graph import OutTensor
|
19
|
+
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
|
20
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
|
21
|
+
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
|
22
|
+
from itertools import product
|
20
23
|
|
21
24
|
|
22
25
|
class FusedLayerType:
|
@@ -27,35 +30,41 @@ class FusedLayerType:
|
|
27
30
|
def __init__(self):
|
28
31
|
self.__name__ = 'FusedLayer'
|
29
32
|
|
30
|
-
|
31
33
|
class GraphFuser:
|
32
|
-
|
33
|
-
def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
|
34
|
+
def apply_node_fusion(self, graph: Graph) -> Graph:
|
34
35
|
"""
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
36
|
+
Applies node fusion to the graph according the fusing_info it has.
|
37
|
+
|
38
|
+
The fusion process includes:
|
39
|
+
1. Generating new fused nodes to replace groups of original nodes.
|
40
|
+
2. Updating the graph structure to replace those nodes with the fused representations.
|
40
41
|
|
41
42
|
Args:
|
42
|
-
graph:
|
43
|
+
graph: The graph and its fusing metadata.
|
43
44
|
|
44
45
|
Returns:
|
45
|
-
|
46
|
+
The updated graph with fused nodes replacing the original node groups.
|
46
47
|
"""
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
48
|
+
graph_copy = copy.deepcopy(graph)
|
49
|
+
expected_fusing_info = FusingInfoGenerator(graph_copy.fusing_info.fusing_patterns).generate_fusing_info(graph_copy)
|
50
|
+
|
51
|
+
if expected_fusing_info != graph_copy.fusing_info:
|
52
|
+
raise ValueError(
|
53
|
+
f"Mismatch between expected and existing fusing information.\n"
|
54
|
+
f"Expected:\n{expected_fusing_info}\nExisting:\n{graph_copy.fusing_info}"
|
55
|
+
)
|
56
|
+
|
57
|
+
fused_operations = list(graph_copy.fusing_info.get_all_fused_operations().items())
|
58
|
+
for fused_node_id, original_nodes in fused_operations:
|
59
|
+
fused_node = self._create_fused_node(fused_node_id, original_nodes)
|
60
|
+
graph_copy.fusing_info.remove_fused_operation(fused_node_id)
|
61
|
+
self._replace_nodes_with_fused_node(graph_copy, original_nodes, fused_node)
|
62
|
+
|
63
|
+
return graph_copy
|
64
|
+
|
56
65
|
|
57
66
|
@staticmethod
|
58
|
-
def _create_fused_node(nodes:
|
67
|
+
def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
|
59
68
|
"""
|
60
69
|
Create a new node that represents the fusion of the given nodes.
|
61
70
|
|
@@ -67,22 +76,28 @@ class GraphFuser:
|
|
67
76
|
"""
|
68
77
|
# Create a new node with a name that reflects its components
|
69
78
|
# Use the input shape of the first node and output shape of the last node
|
70
|
-
|
79
|
+
# TODO: consider replacing the fused node with a sub-model to allow inference on it, etc.
|
80
|
+
fused_node = BaseNode(name=fused_node_id,
|
71
81
|
framework_attr={},
|
72
82
|
input_shape=nodes[0].input_shape,
|
73
83
|
output_shape=nodes[-1].output_shape,
|
74
84
|
weights={},
|
75
85
|
layer_class=FusedLayerType)
|
76
86
|
|
77
|
-
|
78
|
-
|
87
|
+
activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
|
88
|
+
fused_node.candidates_quantization_cfg = [
|
89
|
+
CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a) for a in
|
90
|
+
activation_cfgs]
|
91
|
+
|
92
|
+
# Keep the final configurations if they were set already.
|
93
|
+
fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
|
79
94
|
fused_node.final_activation_quantization_cfg = nodes[-1].final_activation_quantization_cfg
|
80
95
|
|
81
96
|
return fused_node
|
82
97
|
|
83
98
|
@staticmethod
|
84
99
|
def _replace_nodes_with_fused_node(graph: Graph,
|
85
|
-
nodes_to_fuse:
|
100
|
+
nodes_to_fuse: Tuple[BaseNode],
|
86
101
|
fused_node: BaseNode):
|
87
102
|
"""
|
88
103
|
Replace the specified nodes in the graph with a new fused node.
|
@@ -118,6 +133,11 @@ class GraphFuser:
|
|
118
133
|
for next_node in subsequent_nodes:
|
119
134
|
assert next_node in nodes_to_fuse # Ensure we're not removing edges outside the fusion
|
120
135
|
graph.remove_edge(current_node, next_node)
|
136
|
+
# next_node can have more incoming edges from other nodes that are not
|
137
|
+
# in the fusion and we should remove them to:
|
138
|
+
in_edges = graph.incoming_edges(next_node)
|
139
|
+
for ie in in_edges:
|
140
|
+
graph.remove_edge(ie.source_node, next_node)
|
121
141
|
|
122
142
|
# Handle the case where fused nodes are part of the graph's outputs
|
123
143
|
graph_output_tensors = graph.get_outputs()
|
@@ -136,3 +156,5 @@ class GraphFuser:
|
|
136
156
|
|
137
157
|
# Finally, add the new fused node to the graph
|
138
158
|
graph.add_node(fused_node)
|
159
|
+
|
160
|
+
|
@@ -15,7 +15,8 @@
|
|
15
15
|
from collections import namedtuple
|
16
16
|
|
17
17
|
from copy import copy, deepcopy
|
18
|
-
from
|
18
|
+
from functools import wraps
|
19
|
+
from typing import List, Tuple, Any, Callable
|
19
20
|
|
20
21
|
import networkx as nx
|
21
22
|
import numpy as np
|
@@ -23,6 +24,7 @@ import numpy as np
|
|
23
24
|
from networkx.algorithms.dag import topological_sort
|
24
25
|
|
25
26
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
27
|
+
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
|
26
28
|
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
|
27
29
|
from model_compression_toolkit.core.common.graph.edge import Edge, convert_to_edge
|
28
30
|
from model_compression_toolkit.core.common.graph.graph_searches import GraphSearches
|
@@ -36,6 +38,27 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
|
|
36
38
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
37
39
|
FrameworkQuantizationCapabilities
|
38
40
|
|
41
|
+
|
42
|
+
def validate_graph_after_change(method: Callable) -> Callable:
|
43
|
+
"""
|
44
|
+
Decorator for graph-mutating methods. After the decorated method executes,
|
45
|
+
this decorator calls `self.validate()` to ensure the graph remains in a valid state.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
method: The graph-modifying method to wrap.
|
49
|
+
|
50
|
+
Returns:
|
51
|
+
A wrapped method that validates the graph after execution.
|
52
|
+
"""
|
53
|
+
@wraps(method)
|
54
|
+
def wrapper(self, *args, **kwargs):
|
55
|
+
result = method(self, *args, **kwargs)
|
56
|
+
if not self.skip_validation_check:
|
57
|
+
self.validate() # calls Graph.validate(). Ensure graph consistency after changes.
|
58
|
+
return result
|
59
|
+
return wrapper
|
60
|
+
|
61
|
+
|
39
62
|
OutTensor = namedtuple('OutTensor', 'node node_out_index')
|
40
63
|
|
41
64
|
|
@@ -63,6 +86,11 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
63
86
|
"""
|
64
87
|
|
65
88
|
super().__init__(**attr)
|
89
|
+
|
90
|
+
# This must be set first to ensure it's available when validation runs during graph creation.
|
91
|
+
self._skip_validation_check = False
|
92
|
+
self._fusing_info = FusingInfo()
|
93
|
+
|
66
94
|
self.name = name
|
67
95
|
self.input_nodes = input_nodes
|
68
96
|
self.output_nodes = output_nodes
|
@@ -75,7 +103,25 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
75
103
|
**e.get_attributes())
|
76
104
|
self.user_info = UserInformation()
|
77
105
|
self.fw_info = fw_info
|
78
|
-
|
106
|
+
|
107
|
+
@property
|
108
|
+
def skip_validation_check(self) -> bool:
|
109
|
+
return self._skip_validation_check
|
110
|
+
|
111
|
+
@skip_validation_check.setter
|
112
|
+
def skip_validation_check(self, value: bool):
|
113
|
+
if not isinstance(value, bool):
|
114
|
+
raise ValueError("skip_validation_check must be a boolean.")
|
115
|
+
self._skip_validation_check = value
|
116
|
+
|
117
|
+
@property
|
118
|
+
def fusing_info(self) -> FusingInfo:
|
119
|
+
return self._fusing_info
|
120
|
+
|
121
|
+
@fusing_info.setter
|
122
|
+
@validate_graph_after_change
|
123
|
+
def fusing_info(self, fusing_info: FusingInfo):
|
124
|
+
self._fusing_info = fusing_info
|
79
125
|
|
80
126
|
def set_fw_info(self,
|
81
127
|
fw_info: FrameworkInfo):
|
@@ -139,6 +185,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
139
185
|
|
140
186
|
return self.output_nodes
|
141
187
|
|
188
|
+
@validate_graph_after_change
|
142
189
|
def set_inputs(self,
|
143
190
|
input_nodes: List[BaseNode]):
|
144
191
|
"""
|
@@ -149,6 +196,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
149
196
|
|
150
197
|
self.input_nodes = input_nodes
|
151
198
|
|
199
|
+
@validate_graph_after_change
|
152
200
|
def set_outputs(self,
|
153
201
|
output_nodes: List[OutTensor]):
|
154
202
|
"""
|
@@ -321,6 +369,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
321
369
|
sort_attr = None
|
322
370
|
return [edges_list.source_node for edges_list in self.incoming_edges(node_obj, sort_by_attr=sort_attr)]
|
323
371
|
|
372
|
+
@validate_graph_after_change
|
324
373
|
def reconnect_out_edges(self,
|
325
374
|
current_node: BaseNode,
|
326
375
|
new_node: BaseNode):
|
@@ -337,6 +386,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
337
386
|
self.add_edge(new_node, oe.sink_node, **oe.get_attributes())
|
338
387
|
self.remove_edge(current_node, oe.sink_node)
|
339
388
|
|
389
|
+
@validate_graph_after_change
|
340
390
|
def reconnect_in_edges(self,
|
341
391
|
current_node: BaseNode,
|
342
392
|
new_node: BaseNode):
|
@@ -353,6 +403,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
353
403
|
self.add_edge(ie.source_node, new_node, **ie.get_attributes())
|
354
404
|
self.remove_edge(ie.source_node, current_node)
|
355
405
|
|
406
|
+
@validate_graph_after_change
|
356
407
|
def add_node_with_in_edges(self, new_node: BaseNode, input_nodes: List[BaseNode],
|
357
408
|
input_nodes_output_index: List[int] = []):
|
358
409
|
"""
|
@@ -378,6 +429,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
378
429
|
for sink_index, (in_node, source_index) in enumerate(zip(input_nodes, input_nodes_output_index)):
|
379
430
|
self.add_edge(in_node, new_node, source_index=source_index, sink_index=sink_index)
|
380
431
|
|
432
|
+
@validate_graph_after_change
|
381
433
|
def replace_output_node(self,
|
382
434
|
current_node: BaseNode,
|
383
435
|
new_node: BaseNode):
|
@@ -400,6 +452,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
400
452
|
new_graph_outputs[graph_ot_index] = OutTensor(new_node, ot.node_out_index)
|
401
453
|
self.set_outputs(new_graph_outputs)
|
402
454
|
|
455
|
+
@validate_graph_after_change
|
403
456
|
def replace_input_node(self,
|
404
457
|
current_node: BaseNode,
|
405
458
|
new_node: BaseNode):
|
@@ -424,6 +477,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
424
477
|
new_graph_inputs.append(new_node)
|
425
478
|
self.set_inputs(new_graph_inputs)
|
426
479
|
|
480
|
+
@validate_graph_after_change
|
427
481
|
def remove_node(self,
|
428
482
|
node_to_remove: BaseNode,
|
429
483
|
new_graph_inputs: List[BaseNode] = None,
|
@@ -713,16 +767,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
713
767
|
node = prev_nodes[0]
|
714
768
|
return node
|
715
769
|
|
716
|
-
def update_fused_nodes(self, fusion: List[Any]):
|
717
|
-
"""
|
718
|
-
Updates the graphs fusions list with a new list of nodes that have been fused.
|
719
|
-
|
720
|
-
Args:
|
721
|
-
fusion: A list of nodes that have been fused.
|
722
|
-
|
723
|
-
"""
|
724
|
-
self.fused_nodes.append(fusion)
|
725
|
-
|
726
770
|
def has_any_configurable_activation(self) -> bool:
|
727
771
|
"""
|
728
772
|
Checks whether any node in the graph has a configurable activation quantization.
|
@@ -742,6 +786,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
742
786
|
|
743
787
|
return any([n.has_any_configurable_weight() for n in self.nodes])
|
744
788
|
|
789
|
+
@validate_graph_after_change
|
745
790
|
def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
|
746
791
|
"""
|
747
792
|
Replaces a node in the graph with a new node.
|
@@ -867,4 +912,36 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
867
912
|
|
868
913
|
return intermediate_nodes, next_node
|
869
914
|
|
915
|
+
def disable_fused_nodes_activation_quantization(self):
|
916
|
+
"""
|
917
|
+
Disable activation quantization for all nodes in fused operations,
|
918
|
+
except for the last node in each fused group.
|
919
|
+
"""
|
920
|
+
nodes_to_disable = [node for nodes in self.fusing_info.get_all_fused_operations().values() for node in nodes[:-1]]
|
921
|
+
for node in nodes_to_disable:
|
922
|
+
for qc in node.candidates_quantization_cfg:
|
923
|
+
qc.activation_quantization_cfg.enable_activation_quantization = False
|
924
|
+
|
925
|
+
def validate(self):
|
926
|
+
"""
|
927
|
+
Validate that the current state of the graph is consistent with
|
928
|
+
the fusing information (e.g., no missing or incorrect fused node mapping).
|
870
929
|
|
930
|
+
Returns:
|
931
|
+
The result of the FusingInfo validation logic (typically None or raises error).
|
932
|
+
"""
|
933
|
+
return self.fusing_info.validate(self)
|
934
|
+
|
935
|
+
@validate_graph_after_change
|
936
|
+
def add_edge(self, *args, **kwargs):
|
937
|
+
"""
|
938
|
+
Wrap networkx functions (that modifies the graph) with our validate decorator.
|
939
|
+
"""
|
940
|
+
return super().add_edge(*args, **kwargs)
|
941
|
+
|
942
|
+
@validate_graph_after_change
|
943
|
+
def remove_edge(self, *args, **kwargs):
|
944
|
+
"""
|
945
|
+
Wrap networkx functions (that modifies the graph) with our validate decorator.
|
946
|
+
"""
|
947
|
+
return super().remove_edge(*args, **kwargs)
|
@@ -65,6 +65,7 @@ def search_bit_width(graph: Graph,
|
|
65
65
|
bit-width index on the node).
|
66
66
|
|
67
67
|
"""
|
68
|
+
|
68
69
|
assert target_resource_utilization.is_any_restricted()
|
69
70
|
|
70
71
|
# If we only run weights compression with MP than no need to consider activation quantization when computing the
|
@@ -88,6 +89,11 @@ def search_bit_width(graph: Graph,
|
|
88
89
|
if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
|
89
90
|
raise NotImplementedError()
|
90
91
|
|
92
|
+
# Validation is skipped during the mixed-precision search configuration because fusing information is not
|
93
|
+
# relevant for the virtual graph. Therefore, validation checks are disabled before the search begins and
|
94
|
+
# re-enabled once it completes.
|
95
|
+
graph.skip_validation_check = True
|
96
|
+
|
91
97
|
# Search manager and LP are highly coupled, so LP search method was moved inside search manager.
|
92
98
|
search_manager = MixedPrecisionSearchManager(graph,
|
93
99
|
fw_info,
|
@@ -96,6 +102,8 @@ def search_bit_width(graph: Graph,
|
|
96
102
|
target_resource_utilization)
|
97
103
|
result_bit_cfg = search_manager.search()
|
98
104
|
|
105
|
+
graph.skip_validation_check = False
|
106
|
+
|
99
107
|
if mp_config.refine_mp_solution:
|
100
108
|
result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)
|
101
109
|
|
@@ -71,11 +71,13 @@ class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
|
|
71
71
|
|
72
72
|
if weights_quantization_cfg is not None:
|
73
73
|
self.weights_quantization_cfg = weights_quantization_cfg
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
"Ensure QuantizationConfig, OpQuantizationConfig, weights quantization function, "
|
78
|
-
"parameters function, and weights attribute quantization config are provided.")
|
79
|
-
self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc, op_cfg=op_cfg,
|
74
|
+
elif all(v is not None for v in (qc, op_cfg, node_attrs_list)):
|
75
|
+
self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc,
|
76
|
+
op_cfg=op_cfg,
|
80
77
|
weights_channels_axis=weights_channels_axis,
|
81
78
|
node_attrs_list=node_attrs_list)
|
79
|
+
else:
|
80
|
+
self.weights_quantization_cfg = None
|
81
|
+
Logger.debug("Setting weights quantization config as None during CandidateNodeQuantizationConfig creation."
|
82
|
+
"Notice, this should happen only for FLN nodes.")
|
83
|
+
|
@@ -19,11 +19,11 @@ from typing import Callable
|
|
19
19
|
|
20
20
|
import numpy as np
|
21
21
|
|
22
|
+
from model_compression_toolkit.core.common import Graph
|
22
23
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
23
24
|
from model_compression_toolkit.core import common
|
24
25
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
|
25
26
|
from model_compression_toolkit.logger import Logger
|
26
|
-
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
27
27
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
28
28
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
29
29
|
from mct_quantizers import QuantizationMethod
|
@@ -143,6 +143,21 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
|
|
143
143
|
AttributeQuantizationConfig(
|
144
144
|
enable_weights_quantization=False)))
|
145
145
|
|
146
|
+
# Check if the source node was part of a fusion. If so, there are two cases:
|
147
|
+
# either this is no longer a fusion, and the fusion info should be updated by removing
|
148
|
+
# the current info, or this creates a new fusion and the old pattern should be
|
149
|
+
# replaced with the new one.
|
150
|
+
fi = graph.fusing_info
|
151
|
+
fused_op = fi.get_fused_node_name(source_node.name)
|
152
|
+
if fused_op:
|
153
|
+
fused_nodes = list(fi.get_fused_nodes(fused_op))
|
154
|
+
assert source_node in fused_nodes
|
155
|
+
fused_nodes.insert(fused_nodes.index(source_node)+1, bn_node)
|
156
|
+
fi.remove_fused_operation(fused_op)
|
157
|
+
if fi.is_nodes_eligible_to_be_fused(fused_nodes):
|
158
|
+
op_id = fi.generate_fused_op_id(fused_nodes)
|
159
|
+
fi.add_fused_operation(op_id, tuple(fused_nodes))
|
160
|
+
|
146
161
|
graph.reconnect_out_edges(current_node=source_node, new_node=bn_node)
|
147
162
|
graph.replace_output_node(current_node=source_node, new_node=bn_node)
|
148
163
|
graph.add_node_with_in_edges(bn_node, [source_node])
|
@@ -18,7 +18,7 @@ from typing import Callable, Any
|
|
18
18
|
|
19
19
|
from model_compression_toolkit.core.common import FrameworkInfo
|
20
20
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
|
-
from model_compression_toolkit.core.common.fusion.
|
21
|
+
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
|
22
22
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
23
23
|
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
|
24
24
|
from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
|
@@ -136,6 +136,7 @@ def get_finalized_graph(initial_graph: Graph,
|
|
136
136
|
node.prior_info = fw_impl.get_node_prior_info(node=node,
|
137
137
|
fw_info=fw_info,
|
138
138
|
graph=graph)
|
139
|
+
|
139
140
|
##################################################
|
140
141
|
# Graph substitution (pre statistics collection)
|
141
142
|
##################################################
|
@@ -161,7 +162,9 @@ def get_finalized_graph(initial_graph: Graph,
|
|
161
162
|
######################################
|
162
163
|
# Layer fusing
|
163
164
|
######################################
|
164
|
-
|
165
|
+
fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(transformed_graph)
|
166
|
+
transformed_graph.fusing_info = fusing_info
|
167
|
+
transformed_graph.disable_fused_nodes_activation_quantization()
|
165
168
|
|
166
169
|
######################################
|
167
170
|
# Channel equalization
|
@@ -184,15 +184,14 @@ def core_runner(in_model: Any,
|
|
184
184
|
|
185
185
|
scheduler_info = None
|
186
186
|
if core_config.debug_config.simulate_scheduler:
|
187
|
-
|
188
|
-
|
189
|
-
memory_graph = MemoryGraph(graph_to_fuse)
|
187
|
+
fused_graph = GraphFuser().apply_node_fusion(tg)
|
188
|
+
memory_graph = MemoryGraph(fused_graph)
|
190
189
|
schedule, max_cut, cuts = compute_graph_max_cut(memory_graph)
|
191
190
|
scheduler_info = SchedulerInfo(
|
192
191
|
operators_scheduling=schedule,
|
193
192
|
max_cut=float(max_cut),
|
194
193
|
cuts=cuts,
|
195
|
-
fused_nodes_mapping=
|
194
|
+
fused_nodes_mapping=tg.fusing_info.get_node_to_fused_node_map()
|
196
195
|
)
|
197
196
|
|
198
197
|
return tg, bit_widths_config, hessian_info_service, scheduler_info
|
@@ -30,72 +30,8 @@ from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
|
|
30
30
|
OperatorsSetBase,
|
31
31
|
OperatorsSet,
|
32
32
|
OperatorSetGroup,
|
33
|
-
Fusing
|
34
|
-
|
35
|
-
|
36
|
-
class OperatorSetNames(str, Enum):
|
37
|
-
CONV = "Conv"
|
38
|
-
DEPTHWISE_CONV = "DepthwiseConv2D"
|
39
|
-
CONV_TRANSPOSE = "ConvTranspose"
|
40
|
-
FULLY_CONNECTED = "FullyConnected"
|
41
|
-
CONCATENATE = "Concatenate"
|
42
|
-
STACK = "Stack"
|
43
|
-
UNSTACK = "Unstack"
|
44
|
-
GATHER = "Gather"
|
45
|
-
EXPAND = "Expend"
|
46
|
-
BATCH_NORM = "BatchNorm"
|
47
|
-
L2NORM = "L2Norm"
|
48
|
-
RELU = "ReLU"
|
49
|
-
RELU6 = "ReLU6"
|
50
|
-
LEAKY_RELU = "LeakyReLU"
|
51
|
-
ELU = "Elu"
|
52
|
-
HARD_TANH = "HardTanh"
|
53
|
-
ADD = "Add"
|
54
|
-
SUB = "Sub"
|
55
|
-
MUL = "Mul"
|
56
|
-
DIV = "Div"
|
57
|
-
MIN = "Min"
|
58
|
-
MAX = "Max"
|
59
|
-
PRELU = "PReLU"
|
60
|
-
ADD_BIAS = "AddBias"
|
61
|
-
SWISH = "Swish"
|
62
|
-
SIGMOID = "Sigmoid"
|
63
|
-
SOFTMAX = "Softmax"
|
64
|
-
LOG_SOFTMAX = "LogSoftmax"
|
65
|
-
TANH = "Tanh"
|
66
|
-
GELU = "Gelu"
|
67
|
-
HARDSIGMOID = "HardSigmoid"
|
68
|
-
HARDSWISH = "HardSwish"
|
69
|
-
FLATTEN = "Flatten"
|
70
|
-
GET_ITEM = "GetItem"
|
71
|
-
RESHAPE = "Reshape"
|
72
|
-
UNSQUEEZE = "Unsqueeze"
|
73
|
-
SQUEEZE = "Squeeze"
|
74
|
-
PERMUTE = "Permute"
|
75
|
-
TRANSPOSE = "Transpose"
|
76
|
-
DROPOUT = "Dropout"
|
77
|
-
SPLIT_CHUNK = "SplitChunk"
|
78
|
-
MAXPOOL = "MaxPool"
|
79
|
-
AVGPOOL = "AvgPool"
|
80
|
-
SIZE = "Size"
|
81
|
-
SHAPE = "Shape"
|
82
|
-
EQUAL = "Equal"
|
83
|
-
ARGMAX = "ArgMax"
|
84
|
-
TOPK = "TopK"
|
85
|
-
FAKE_QUANT = "FakeQuant"
|
86
|
-
COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
|
87
|
-
BOX_DECODE = "BoxDecode"
|
88
|
-
ZERO_PADDING2D = "ZeroPadding2D"
|
89
|
-
CAST = "Cast"
|
90
|
-
RESIZE = "Resize"
|
91
|
-
PAD = "Pad"
|
92
|
-
FOLD = "Fold"
|
93
|
-
STRIDED_SLICE = "StridedSlice"
|
94
|
-
SSD_POST_PROCESS = "SSDPostProcess"
|
95
|
-
|
96
|
-
@classmethod
|
97
|
-
def get_values(cls):
|
98
|
-
return [v.value for v in cls]
|
33
|
+
Fusing,
|
34
|
+
OperatorSetNames)
|
99
35
|
|
100
36
|
|
101
37
|
class TargetPlatformCapabilities(BaseModel):
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py
CHANGED
@@ -93,7 +93,6 @@ class AttachTpcToKeras(AttachTpcToFramework):
|
|
93
93
|
OperatorSetNames.TOPK: [tf.nn.top_k],
|
94
94
|
OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
|
95
95
|
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
|
96
|
-
OperatorSetNames.BOX_DECODE: [], # no such operator in keras
|
97
96
|
OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
|
98
97
|
OperatorSetNames.CAST: [tf.cast],
|
99
98
|
OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py
CHANGED
@@ -98,7 +98,6 @@ class AttachTpcToPytorch(AttachTpcToFramework):
|
|
98
98
|
Eq('p', 2) | Eq('p', None))],
|
99
99
|
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
|
100
100
|
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [], # no such operator in pytorch
|
101
|
-
OperatorSetNames.BOX_DECODE: [] # no such operator in pytorch
|
102
101
|
}
|
103
102
|
|
104
103
|
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
|
@@ -1,131 +0,0 @@
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
import copy
|
16
|
-
from typing import Any, List
|
17
|
-
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
18
|
-
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
20
|
-
FrameworkQuantizationCapabilities
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import LayerFilterParams
|
22
|
-
|
23
|
-
|
24
|
-
def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx: int = 0) -> List[List[Any]]:
|
25
|
-
"""
|
26
|
-
Update relevant fusing patterns object if layer number 'idx' inside the fusion matches the node
|
27
|
-
Args:
|
28
|
-
fusing_patterns: supported fusings
|
29
|
-
node: node to decide if it can be a part of fusion
|
30
|
-
idx: index of layer in the fusion
|
31
|
-
Returns:
|
32
|
-
fusing_patterns after filtering non-relevant fusions
|
33
|
-
"""
|
34
|
-
valid_fusing_patterns = []
|
35
|
-
for i, fusing_pattern in enumerate(fusing_patterns):
|
36
|
-
if idx < len(fusing_pattern):
|
37
|
-
if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or \
|
38
|
-
node.is_match_type(fusing_pattern[idx]):
|
39
|
-
valid_fusing_patterns.append(fusing_pattern)
|
40
|
-
|
41
|
-
# Return only valid patterns for this node
|
42
|
-
return valid_fusing_patterns
|
43
|
-
|
44
|
-
|
45
|
-
def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) -> bool:
|
46
|
-
"""
|
47
|
-
Check if the fusion is valid: exist in fusing_patterns
|
48
|
-
Args:
|
49
|
-
fusing_patterns: supported fusing patterns
|
50
|
-
nodes: nodes which are participating in fusion
|
51
|
-
Returns:
|
52
|
-
whether the fusion in valid
|
53
|
-
"""
|
54
|
-
fusion_depth = len(nodes)
|
55
|
-
if fusion_depth <= 1:
|
56
|
-
return False
|
57
|
-
for fusing_pattern in fusing_patterns:
|
58
|
-
if fusion_depth != len(fusing_pattern):
|
59
|
-
continue
|
60
|
-
counter = 0
|
61
|
-
for i, layer in enumerate(fusing_pattern):
|
62
|
-
if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
|
63
|
-
nodes[i].is_match_type(layer):
|
64
|
-
counter += 1
|
65
|
-
if counter == fusion_depth:
|
66
|
-
return True
|
67
|
-
return False
|
68
|
-
|
69
|
-
|
70
|
-
def disable_nodes_activation_quantization(nodes: List[BaseNode]):
|
71
|
-
"""
|
72
|
-
Disable activation for non-quantization needed due to fusion
|
73
|
-
Args:
|
74
|
-
nodes: nodes to update their activation quantization
|
75
|
-
"""
|
76
|
-
for node in nodes:
|
77
|
-
for qc in node.candidates_quantization_cfg:
|
78
|
-
qc.activation_quantization_cfg.enable_activation_quantization = False
|
79
|
-
|
80
|
-
|
81
|
-
def fusion(graph: Graph, fqc: FrameworkQuantizationCapabilities) -> Graph:
|
82
|
-
"""
|
83
|
-
Fusing defines a list of operators that should be combined and treated as a single operator,
|
84
|
-
hence no quantization is applied between them when they appear in the graph.
|
85
|
-
This function search and disable quantization for such patterns.
|
86
|
-
Args:
|
87
|
-
graph: Graph we apply the fusion on.
|
88
|
-
fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
|
89
|
-
Returns:
|
90
|
-
Graph after applying fusion activation marking.
|
91
|
-
"""
|
92
|
-
fusing_patterns = fqc.get_fusing_patterns()
|
93
|
-
if len(fusing_patterns) == 0:
|
94
|
-
return graph
|
95
|
-
|
96
|
-
# Find max fusion
|
97
|
-
max_layers_fusing = 0 if len(fusing_patterns) == 0 else max([len(fusing_pattern) for fusing_pattern in fusing_patterns])
|
98
|
-
|
99
|
-
|
100
|
-
# -------------------------------- #
|
101
|
-
# Fusion algorithm
|
102
|
-
# -------------------------------- #
|
103
|
-
fused_graph = copy.deepcopy(graph)
|
104
|
-
|
105
|
-
# Travel along the graph to find layers for fusing
|
106
|
-
nodes = fused_graph.get_topo_sorted_nodes()
|
107
|
-
fused_nodes = [] # nodes that are participating in fusing
|
108
|
-
for node in nodes:
|
109
|
-
# Skip if already in fusing
|
110
|
-
if node in fused_nodes:
|
111
|
-
continue
|
112
|
-
# Start fusing search
|
113
|
-
fusing_nodes = [] # nodes that are candidates for participating in fusing
|
114
|
-
patterns = copy.deepcopy(fusing_patterns)
|
115
|
-
next_nodes = [node]
|
116
|
-
for i in range(max_layers_fusing):
|
117
|
-
patterns = filter_fusing_patterns(patterns, next_nodes[0], i)
|
118
|
-
if len(patterns) == 0: # Give up if no more fusion pattern
|
119
|
-
break
|
120
|
-
fusing_nodes.append(next_nodes[0])
|
121
|
-
next_nodes = fused_graph.get_next_nodes(fusing_nodes[-1])
|
122
|
-
if len(next_nodes) != 1: # Give up if node has more than one connection (not supported for fusion)
|
123
|
-
break
|
124
|
-
|
125
|
-
# New fusion: mark all nodes in the fusion except last one
|
126
|
-
if is_valid_fusion(fusing_patterns, fusing_nodes):
|
127
|
-
fused_nodes.extend(fusing_nodes)
|
128
|
-
disable_nodes_activation_quantization(fusing_nodes[:-1])
|
129
|
-
fused_graph.update_fused_nodes(fusing_nodes)
|
130
|
-
|
131
|
-
return fused_graph
|
File without changes
|
File without changes
|
{mct_nightly-2.3.0.20250402.536.dist-info → mct_nightly-2.3.0.20250404.535.dist-info}/top_level.txt
RENAMED
File without changes
|