mct-nightly 2.3.0.20250401.618__py3-none-any.whl → 2.3.0.20250403.518__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19):
  1. {mct_nightly-2.3.0.20250401.618.dist-info → mct_nightly-2.3.0.20250403.518.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250401.618.dist-info → mct_nightly-2.3.0.20250403.518.dist-info}/RECORD +18 -17
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/fusion/fusing_info.py +374 -0
  5. model_compression_toolkit/core/common/fusion/graph_fuser.py +50 -28
  6. model_compression_toolkit/core/common/graph/base_graph.py +89 -12
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +8 -0
  8. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +8 -6
  9. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +16 -1
  10. model_compression_toolkit/core/graph_prep_runner.py +5 -2
  11. model_compression_toolkit/core/runner.py +3 -4
  12. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
  13. model_compression_toolkit/target_platform_capabilities/schema/v2.py +177 -0
  14. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py +1 -0
  15. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +2 -1
  16. model_compression_toolkit/core/common/fusion/layer_fusing.py +0 -131
  17. {mct_nightly-2.3.0.20250401.618.dist-info → mct_nightly-2.3.0.20250403.518.dist-info}/WHEEL +0 -0
  18. {mct_nightly-2.3.0.20250401.618.dist-info → mct_nightly-2.3.0.20250403.518.dist-info}/licenses/LICENSE.md +0 -0
  19. {mct_nightly-2.3.0.20250401.618.dist-info → mct_nightly-2.3.0.20250403.518.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250401.618
3
+ Version: 2.3.0.20250403.518
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -1,5 +1,5 @@
1
- mct_nightly-2.3.0.20250401.618.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=sgIxUWX9jDSvUZnwMqs3nHNjXfhgFSfniDDr2vvRTuQ,1557
1
+ mct_nightly-2.3.0.20250403.518.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=z7w2tRBoJC1dhCtnjEKyK834X-V0TBq_pKLiNWkHc5s,1557
3
3
  model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -7,9 +7,9 @@ model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_A
7
7
  model_compression_toolkit/verify_packages.py,sha256=TlS-K1EP-QsghqWUW7SDPkAJiUf7ryw4tvhFDe6rCUk,1405
8
8
  model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZG1FUxxJlAV8L8,2096
9
9
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
10
- model_compression_toolkit/core/graph_prep_runner.py,sha256=CVTjBaci8F6EP3IKDnRMfxkP-Sv8qY8GpkGt6FyII2U,11376
10
+ model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
11
11
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
12
- model_compression_toolkit/core/runner.py,sha256=WjZMVXc-OGBTnkiH0PRjNdJEM5pKQRPvLHXor5tjwjk,13096
12
+ model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utybvIVbH06CBHY,13056
13
13
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
14
14
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
15
15
  model_compression_toolkit/core/common/framework_implementation.py,sha256=s3yiqnbWkwfnAB1sSal_KAuqVg27rLhAJ2O8LHUbSHE,22494
@@ -31,10 +31,10 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
31
31
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
32
32
  model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
33
33
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
34
- model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=b41_4rL_Adiza4vpWlmmqgvkpUmWVdfdx0nEIB0p2n8,6195
35
- model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=-2fnjyC9q2RPw9st6RxROW-gdtT2mSRz0QZ_Gz1KDz4,5579
34
+ model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=LfzVS9B6r2KCwf8rcCUdepEQhWkt287SoXfwoudpfFo,15496
35
+ model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
36
36
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
37
- model_compression_toolkit/core/common/graph/base_graph.py,sha256=cSwHUqwZEiR1t2DaBfc7_qSJbtX8crpqerN4ol9v3H8,38859
37
+ model_compression_toolkit/core/common/graph/base_graph.py,sha256=hedhjVula5rPv0vN0CLBDtPYM8SH3cM6FAL62aFfF7U,41767
38
38
  model_compression_toolkit/core/common/graph/base_node.py,sha256=CJu8_r80MGVnYmlAUGOGKGRsD9xShMyaRNb3VMeRC0s,34523
39
39
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
40
40
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
@@ -68,7 +68,7 @@ model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha2
68
68
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
69
69
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
70
70
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=2Pp4hiYvGW2I9YhloDxQNT0sZRg3TDp9CXObloF8IFU,4971
71
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=LddWtLileazCOvVSz-7j-GA4yskcGD3UHQGo7XUzSTE,5661
71
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=GGrp7QngrWvWtPN8cQnL4IEbNwcVRc-hAUqfnxjjMmk,5998
72
72
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=NBzzhkVI407S9cIiw7t7nsP3MrkOdSnweKQdPBXb8to,38180
73
73
  model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=gsigifJ-ykWNafF4t7UMEC_-nd6YPERAk1_z0kT-Y88,27172
74
74
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
@@ -102,7 +102,7 @@ model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=77
102
102
  model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=_LcDAxLeC5I0KdMHS8jib5XxIKO2ZLavXYuSMIPIQBo,5868
103
103
  model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
104
104
  model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=0HA3CIZW-ZrA55ra-yJXRvAYnoR8i1SjpbnMDKcWYNQ,12819
105
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=u7uueixA5wi3eYPrZKtLVxogkmgcgFL1w2pzMfd_ToU,4950
105
+ model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=lyWPvnoX8BmulhLKR20r5gT2_Yan7P40d8EcgDhErPk,4905
106
106
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
107
107
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
108
108
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=IHVX-Gdekru4xLuDTgcsp_JCnRtuVWnbYsDBQuSXTKc,7079
@@ -138,7 +138,7 @@ model_compression_toolkit/core/common/statistics_correction/statistics_correctio
138
138
  model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
139
139
  model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
140
140
  model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
141
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=10fITLy6in5eLfDe415eTdJnTkdTDialfUhBffFYYw0,7634
141
+ model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=1389z4NbTKIHYGr-FB-fV1YP1Gcfta0tOu60DwfNVlI,8452
142
142
  model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=dWJpVfomF4Ppeeor3VzS23TXHyBm85QI7snyLOYP_ko,9972
143
143
  model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
144
144
  model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
@@ -435,13 +435,14 @@ model_compression_toolkit/target_platform_capabilities/constants.py,sha256=BFSgD
435
435
  model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
436
436
  model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=4ydTWWKv_PEOAFok2JtxFNj8rav-0IlqcXKF6lnhHNE,4157
437
437
  model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
438
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
438
+ model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=hf539WJ3nBGn0RnALXrKmAPnbhJ-VmWmLIa207x8b4M,541
439
439
  model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
440
440
  model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=4CGpWENuOyjwaIMaGrFI0Act7jsSeT7m94pjrv91dxE,27516
441
+ model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=vUhCocA0EcjdR741Yv48W4Kr5Pq22Miebhm7F9GKb3Y,6086
441
442
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
442
443
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
443
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=mxc3DBbUi-HDFgSx8Nmnyxr8SIdbx8lmtcRMsQl1BLE,7578
444
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=WPCqs_aFGE28XJf7KKB-SlrYoUNOcD9epgoaqQMCJMw,6320
444
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=AE09QLE_QKwNqUTZbkZP9XLJStG1ECiTWmEGuXZTEsQ,7652
445
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=-zbPmzQJal-1vZiQ6vIBBBnlEOB2DTb09koA0Aj4I_I,6396
445
446
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
446
447
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
447
448
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
@@ -526,7 +527,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
526
527
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
527
528
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
528
529
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
529
- mct_nightly-2.3.0.20250401.618.dist-info/METADATA,sha256=53LoSDV2ox7X64SeEb9OwP4UsuLi75QeSwyFLponCrQ,27098
530
- mct_nightly-2.3.0.20250401.618.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
531
- mct_nightly-2.3.0.20250401.618.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
532
- mct_nightly-2.3.0.20250401.618.dist-info/RECORD,,
530
+ mct_nightly-2.3.0.20250403.518.dist-info/METADATA,sha256=D6WPQRCnXD6lqzCblmPu_dLfulyf5bSMcbH-9mm_nNI,27098
531
+ mct_nightly-2.3.0.20250403.518.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
532
+ mct_nightly-2.3.0.20250403.518.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
533
+ mct_nightly-2.3.0.20250403.518.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250401.000618"
30
+ __version__ = "2.3.0.20250403.000518"
@@ -0,0 +1,374 @@
1
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
17
+ from dataclasses import dataclass, field
18
+
19
+ from typing import Optional, List, Dict, Any, Tuple
20
+ import copy
21
+
22
+ # The prefix of each fused operator (the suffix is a combination of the
23
+ # nodes names that combine the fused operator).
24
+ FUSED_OP_ID_PREFIX = "FusedNode_"
25
+
26
+
27
+ @dataclass
28
+ class FusingInfo:
29
+ """
30
+ This class manages information about fused operations in a graph.
31
+
32
+ The key responsibility of this class is maintaining a mapping between original nodes
33
+ and their corresponding fused operation IDs. This mapping helps track which nodes
34
+ belong to fused operations and validate this info is correct after changes in the graph.
35
+
36
+ The core structures maintained are:
37
+ - `fusing_data`: A dictionary mapping fused operation IDs to lists of nodes that belong to that operation.
38
+ - `node_to_fused_node_map`: A dictionary mapping each node name to the ID of the fused operation it belongs to.
39
+
40
+ """
41
+ fusing_patterns: any = None
42
+ fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
43
+ node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
44
+
45
+ def __post_init__(self):
46
+ """Validates and initializes mappings after dataclass instantiation."""
47
+ for op_id, op_nodes in self.fusing_data.items():
48
+ assert isinstance(op_id, str) and op_id.startswith(FUSED_OP_ID_PREFIX), f"Found invalid fused op id: {op_id}"
49
+ assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"
50
+
51
+ self._init_node_mapping()
52
+
53
+ def _init_node_mapping(self) -> None:
54
+ """
55
+ Init the node-to-fused-node mapping based on the initial fusing data.
56
+ """
57
+ self.node_to_fused_node_map.clear()
58
+ for op_id, nodes in self.fusing_data.items():
59
+ for node in nodes:
60
+ self.node_to_fused_node_map[node.name] = op_id
61
+
62
+ def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
63
+ """
64
+ Add a new fused operation with the given ID and set of nodes.
65
+
66
+ Args:
67
+ op_id (str): The identifier for the fused operation.
68
+ nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
69
+
70
+ Raises:
71
+ ValueError: If the operation ID already exists.
72
+ """
73
+ if op_id in self.fusing_data:
74
+ raise ValueError(f"Fused operation {op_id} already exists.")
75
+ assert isinstance(nodes, tuple), f"Expected nodes to be a tuple but its type is {type(nodes)}"
76
+ self.fusing_data[op_id] = nodes
77
+ # Update the mapping for these nodes
78
+ for node in nodes:
79
+ self.node_to_fused_node_map[node.name] = op_id
80
+
81
+ def remove_fused_operation(self, op_id: str) -> None:
82
+ """
83
+ Remove a fused operation by its ID.
84
+
85
+ Args:
86
+ op_id (str): The identifier for the fused operation to remove.
87
+
88
+ Raises:
89
+ ValueError: If the operation ID does not exist.
90
+ """
91
+ if op_id not in self.fusing_data:
92
+ raise ValueError(f"Fused operation {op_id} does not exist.")
93
+ # Remove nodes from the mapping
94
+ nodes = self.fusing_data[op_id]
95
+ for node in nodes:
96
+ self.node_to_fused_node_map.pop(node.name, None)
97
+ del self.fusing_data[op_id]
98
+
99
+ def get_fused_node_name(self, node_name: str) -> Optional[str]:
100
+ """
101
+ Get the name of the fused node containing the given original node name.
102
+
103
+ Args:
104
+ node_name: The name of a node from the original graph.
105
+
106
+ Returns:
107
+ The name of the fused node containing this node, or None if not fused.
108
+ """
109
+ return self.node_to_fused_node_map.get(node_name)
110
+
111
+ def get_node_to_fused_node_map(self) -> Dict[str, str]:
112
+ """
113
+ Retrieve a copy of the mapping from original node names to fused node names.
114
+
115
+ Returns:
116
+ A dictionary mapping each original node name to its fused node name.
117
+ """
118
+ return self.node_to_fused_node_map.copy()
119
+
120
+ def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
121
+ """
122
+ Retrieve the list of nodes for a given fused operation ID.
123
+
124
+ Args:
125
+ op_id (str): The identifier for the fused operation.
126
+
127
+ Returns:
128
+ Optional[List[BaseNode]]: The list of nodes for the operation, or None if not found.
129
+ """
130
+ return self.fusing_data.get(op_id)
131
+
132
+ def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
133
+ """
134
+ Check if a node is part of any fused operation.
135
+
136
+ Args:
137
+ node (BaseNode): The node to check.
138
+
139
+ Returns:
140
+ bool: True if the node is in any fused operation, False otherwise.
141
+ """
142
+ return any(node in nodes for nodes in self.fusing_data.values())
143
+
144
+ def get_all_fused_operations(self) -> Dict[str, Tuple['BaseNode']]:
145
+ """
146
+ Retrieve fused information.
147
+
148
+ Returns:
149
+ Dict[str, List[BaseNode]]: The fusing data.
150
+ """
151
+ return self.fusing_data
152
+
153
+
154
+ @staticmethod
155
+ def generate_fused_op_id(nodes: List['BaseNode']) -> str:
156
+ """
157
+ Generates an identifier for a fused operation by concatenating
158
+ the names of the given nodes with a prefix.
159
+
160
+ Args:
161
+ nodes (List[BaseNode]): A list of nodes to be fused.
162
+
163
+ Returns:
164
+ str: An identifier string for the fused operation.
165
+ """
166
+ id = FUSED_OP_ID_PREFIX + '_'.join([node.name for node in nodes])
167
+ return id
168
+
169
+ def validate(self, graph) -> None:
170
+ """
171
+ Validate that the fusing information is consistent with the given graph and generation logic.
172
+
173
+ This method performs the following checks:
174
+ 1. All nodes in the fusing data exist in the graph.
175
+ 2. Each fused sequence forms a valid linear chain in the graph:
176
+ - Each node (except the last) has exactly one successor, which is the next node in the sequence.
177
+ 3. No node is part of more than one fused operation.
178
+ 4. Each fused sequence matches a valid fusing pattern from the original set.
179
+
180
+ Args:
181
+ graph: The computational graph to validate against. It is expected to have:
182
+ - `get_topo_sorted_nodes()`: Returns a list of nodes in topological order.
183
+ - `get_next_nodes(node)`: Returns a list of direct successor nodes.
184
+
185
+ Raises:
186
+ ValueError: If any validation check fails.
187
+ """
188
+ graph_nodes = set(graph.get_topo_sorted_nodes()) # Retrieve all nodes from the graph
189
+ all_fused_nodes = set() # Track all nodes used in fusions to ensure no overlap
190
+
191
+ for op_id, nodes in self.fusing_data.items():
192
+ # Check 1: Ensure all fused nodes exist in the graph
193
+ for node in nodes:
194
+ if node not in graph_nodes:
195
+ raise ValueError(f"Fused operation {op_id} contains node {node.name} not present in the graph.")
196
+
197
+ # Check 2: Validate the fusion sequence forms a valid linear chain
198
+ for i in range(len(nodes) - 1): # Up to the second-to-last node
199
+ current_node = nodes[i]
200
+ next_node = nodes[i + 1]
201
+ successors = graph.get_next_nodes(current_node)
202
+ if len(successors) != 1 or successors[0] != next_node:
203
+ raise ValueError(
204
+ f"Fused operation {op_id} is not a valid linear chain: "
205
+ f"node {current_node.name} does not connect directly to {next_node.name} "
206
+ f"with exactly one successor (found successors: {[n.name for n in successors]})."
207
+ )
208
+
209
+ # Check 3: Ensure no node is reused across fusions
210
+ node_set = set(nodes)
211
+ overlap = node_set & all_fused_nodes
212
+ if overlap:
213
+ raise ValueError(
214
+ f"Fused operation {op_id} contains nodes already used in another fusion: "
215
+ f"{[node.name for node in overlap]}."
216
+ )
217
+ all_fused_nodes.update(node_set)
218
+
219
+ # Check 4: Ensure the sequence matches a valid fusing pattern
220
+ if not is_valid_fusion(self.fusing_patterns, nodes):
221
+ raise ValueError(
222
+ f"Fused operation {op_id} does not match any valid fusing pattern "
223
+ f"from {self.fusing_patterns}."
224
+ )
225
+
226
+ def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
227
+ """
228
+ Check whether the given nodes are eligible to be fused based on predefined fusing patterns.
229
+
230
+ This method retrieves the fusing patterns from `self.fqc` and verifies whether the
231
+ given sequence of nodes matches any of the valid patterns.
232
+
233
+ Args:
234
+ nodes (List[BaseNode]): The list of nodes to check for fusion eligibility.
235
+
236
+ Returns:
237
+ bool: True if the nodes can be fused according to fusing patterns, otherwise False.
238
+ """
239
+ # If no fusing patterns are defined, fusion is not possible
240
+ if not self.fusing_patterns:
241
+ return False
242
+
243
+ # Check if the provided nodes match a valid fusion pattern
244
+ return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes)
245
+
246
+ def __repr__(self) -> str:
247
+ """
248
+ Return a string representation of the fusing information.
249
+ """
250
+ fusing_data_repr = "\n".join(
251
+ f" {op_id}: [{', '.join(node.name for node in nodes)}]"
252
+ for op_id, nodes in self.fusing_data.items()
253
+ )
254
+ mapping_repr = ", ".join(
255
+ f"{node} -> {op_id}" for node, op_id in self.node_to_fused_node_map.items()
256
+ )
257
+ return (
258
+ f"FusingInfo(\n"
259
+ f" Total fused operations: {len(self.fusing_data)}\n"
260
+ f" Fusing Data:\n{fusing_data_repr}\n"
261
+ f" Node-to-Fused Mapping:\n {mapping_repr}\n"
262
+ f")"
263
+ )
264
+
265
+
266
+ class FusingInfoGenerator:
267
+ def __init__(self, fusing_patterns):
268
+ self._fusing_patterns = fusing_patterns
269
+
270
+ def generate_fusing_info(self, graph) -> FusingInfo:
271
+ """
272
+ Generate fusing information based on the graph and fusing patterns.
273
+
274
+ Args:
275
+ graph: The input graph to analyze, expected to have methods like
276
+ get_topo_sorted_nodes() and get_next_nodes(node).
277
+
278
+ Returns:
279
+ A dictionary where keys are unique fusion identifiers (e.g., 'fused_op_0')
280
+ and values are lists of BaseNode objects representing nodes in that fusion.
281
+
282
+ Notes:
283
+ - Assumes get_valid_fusing_patterns_for_node and is_valid_fusion functions are defined elsewhere.
284
+ - Nodes are processed in topological order to respect operation sequence.
285
+ - Fusions are linear sequences (each node has exactly one successor).
286
+ - Each node belongs to at most one fused operation.
287
+ """
288
+ if not self._fusing_patterns:
289
+ return FusingInfo(fusing_patterns=self._fusing_patterns)
290
+
291
+ # Find max fusion
292
+ max_layers_fusing = 0 if len(self._fusing_patterns) == 0 else max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
293
+
294
+ # Travel along the graph to find layers for fusing
295
+ nodes = graph.get_topo_sorted_nodes()
296
+
297
+ fusing_info: Dict[str, Tuple['BaseNode']] = {}
298
+ fused_nodes = [] # nodes that are participating in fusing
299
+
300
+ for node in nodes:
301
+ # Skip if already in fusing
302
+ if node in fused_nodes:
303
+ continue
304
+ # Start fusing search
305
+ fusing_nodes = [] # nodes that are candidates for participating in fusing
306
+ patterns = copy.deepcopy(self._fusing_patterns)
307
+ next_nodes = [node]
308
+ for i in range(max_layers_fusing):
309
+ patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
310
+ if len(patterns) == 0: # Give up if no more fusion pattern
311
+ break
312
+ fusing_nodes.append(next_nodes[0])
313
+ next_nodes = graph.get_next_nodes(fusing_nodes[-1])
314
+ if len(next_nodes) != 1: # Give up if node has more than one connection (not supported for fusion)
315
+ break
316
+
317
+ # New fusion
318
+ if is_valid_fusion(self._fusing_patterns, fusing_nodes):
319
+ fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
320
+ assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
321
+ fusing_info[fused_op_id] = tuple(fusing_nodes)
322
+ fused_nodes.extend(fusing_nodes)
323
+
324
+ return FusingInfo(fusing_data=fusing_info, fusing_patterns=self._fusing_patterns)
325
+
326
+
327
+ def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]],
328
+ node: 'BaseNode',
329
+ idx: int = 0) -> List[List[Any]]:
330
+ """
331
+ Returns only the fusing patterns where a specific layer (at index idx) matches the given node — either by type or filter params.
332
+
333
+ Args:
334
+ fusing_patterns: supported fusings
335
+ node: node to decide if it can be a part of fusion
336
+ idx: index of layer in the fusion
337
+
338
+ Returns:
339
+ fusing_patterns after filtering non-relevant fusions
340
+ """
341
+ valid_fusing_patterns = []
342
+ for i, fusing_pattern in enumerate(fusing_patterns):
343
+ if idx < len(fusing_pattern):
344
+ if ((type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(
345
+ fusing_pattern[idx])) or node.is_match_type(fusing_pattern[idx])):
346
+ valid_fusing_patterns.append(fusing_pattern)
347
+
348
+ # Return only valid patterns for this node
349
+ return valid_fusing_patterns
350
+
351
+
352
+ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -> bool:
353
+ """
354
+ Check if the fusion is valid: exist in fusing_patterns
355
+ Args:
356
+ fusing_patterns: supported fusing patterns
357
+ nodes: nodes which are participating in fusion
358
+ Returns:
359
+ whether the fusion in valid
360
+ """
361
+ fusion_depth = len(nodes)
362
+ if fusion_depth <= 1:
363
+ return False
364
+ for fusing_pattern in fusing_patterns:
365
+ if fusion_depth != len(fusing_pattern):
366
+ continue
367
+ counter = 0
368
+ for i, layer in enumerate(fusing_pattern):
369
+ if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
370
+ nodes[i].is_match_type(layer):
371
+ counter += 1
372
+ if counter == fusion_depth:
373
+ return True
374
+ return False
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,10 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Dict, List
16
+ import copy
17
+ from typing import List, Tuple
17
18
 
18
- from model_compression_toolkit.core.common import Graph, BaseNode
19
- from model_compression_toolkit.core.common.graph.base_graph import OutTensor
19
+ from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
20
+ from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
21
+ from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
22
+ from itertools import product
20
23
 
21
24
 
22
25
  class FusedLayerType:
@@ -27,35 +30,41 @@ class FusedLayerType:
27
30
  def __init__(self):
28
31
  self.__name__ = 'FusedLayer'
29
32
 
30
-
31
33
  class GraphFuser:
32
-
33
- def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
34
+ def apply_node_fusion(self, graph: Graph) -> Graph:
34
35
  """
35
- GraphFuser is responsible for fusing nodes in a networkx graph.
36
- The fusion process involves:
37
- 1. Creating new fused nodes to represent these groups.
38
- 2. Updating the graph structure to replace the original nodes with fused nodes.
39
- 3. Maintaining mapping of original node names to their fused node names.
36
+ Applies node fusion to the graph according the fusing_info it has.
37
+
38
+ The fusion process includes:
39
+ 1. Generating new fused nodes to replace groups of original nodes.
40
+ 2. Updating the graph structure to replace those nodes with the fused representations.
40
41
 
41
42
  Args:
42
- graph: Graph to fuse its nodes.
43
+ graph: The graph and its fusing metadata.
43
44
 
44
45
  Returns:
45
- Mapping of original node names to their fused node names
46
+ The updated graph with fused nodes replacing the original node groups.
46
47
  """
47
- fused_nodes_mapping = {}
48
- # Iterate through each group of nodes to be fused
49
- for fused_nodes_list in graph.fused_nodes:
50
- new_fused_node = self._create_fused_node(fused_nodes_list)
51
- self._replace_nodes_with_fused_node(graph, fused_nodes_list, new_fused_node)
52
- # Update the mapping to keep track of which original nodes are now part of which fused nodes
53
- for node in fused_nodes_list:
54
- fused_nodes_mapping[node.name] = new_fused_node.name
55
- return fused_nodes_mapping
48
+ graph_copy = copy.deepcopy(graph)
49
+ expected_fusing_info = FusingInfoGenerator(graph_copy.fusing_info.fusing_patterns).generate_fusing_info(graph_copy)
50
+
51
+ if expected_fusing_info != graph_copy.fusing_info:
52
+ raise ValueError(
53
+ f"Mismatch between expected and existing fusing information.\n"
54
+ f"Expected:\n{expected_fusing_info}\nExisting:\n{graph_copy.fusing_info}"
55
+ )
56
+
57
+ fused_operations = list(graph_copy.fusing_info.get_all_fused_operations().items())
58
+ for fused_node_id, original_nodes in fused_operations:
59
+ fused_node = self._create_fused_node(fused_node_id, original_nodes)
60
+ graph_copy.fusing_info.remove_fused_operation(fused_node_id)
61
+ self._replace_nodes_with_fused_node(graph_copy, original_nodes, fused_node)
62
+
63
+ return graph_copy
64
+
56
65
 
57
66
  @staticmethod
58
- def _create_fused_node(nodes: List[BaseNode]) -> BaseNode:
67
+ def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
59
68
  """
60
69
  Create a new node that represents the fusion of the given nodes.
61
70
 
@@ -67,22 +76,28 @@ class GraphFuser:
67
76
  """
68
77
  # Create a new node with a name that reflects its components
69
78
  # Use the input shape of the first node and output shape of the last node
70
- fused_node = BaseNode(name='FusedNode_' + '_'.join([node.name for node in nodes]),
79
+ # TODO: consider replacing the fused node with a sub-model to allow inference on it, etc.
80
+ fused_node = BaseNode(name=fused_node_id,
71
81
  framework_attr={},
72
82
  input_shape=nodes[0].input_shape,
73
83
  output_shape=nodes[-1].output_shape,
74
84
  weights={},
75
85
  layer_class=FusedLayerType)
76
86
 
77
- # Preserve the final activation quantization configuration
78
- # This is important for maintaining the correct behavior of the fused node
87
+ activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
88
+ fused_node.candidates_quantization_cfg = [
89
+ CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a) for a in
90
+ activation_cfgs]
91
+
92
+ # Keep the final configurations if they were set already.
93
+ fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
79
94
  fused_node.final_activation_quantization_cfg = nodes[-1].final_activation_quantization_cfg
80
95
 
81
96
  return fused_node
82
97
 
83
98
  @staticmethod
84
99
  def _replace_nodes_with_fused_node(graph: Graph,
85
- nodes_to_fuse: List[BaseNode],
100
+ nodes_to_fuse: Tuple[BaseNode],
86
101
  fused_node: BaseNode):
87
102
  """
88
103
  Replace the specified nodes in the graph with a new fused node.
@@ -118,6 +133,11 @@ class GraphFuser:
118
133
  for next_node in subsequent_nodes:
119
134
  assert next_node in nodes_to_fuse # Ensure we're not removing edges outside the fusion
120
135
  graph.remove_edge(current_node, next_node)
136
+ # next_node can have more incoming edges from other nodes that are not
137
+ # in the fusion and we should remove them to:
138
+ in_edges = graph.incoming_edges(next_node)
139
+ for ie in in_edges:
140
+ graph.remove_edge(ie.source_node, next_node)
121
141
 
122
142
  # Handle the case where fused nodes are part of the graph's outputs
123
143
  graph_output_tensors = graph.get_outputs()
@@ -136,3 +156,5 @@ class GraphFuser:
136
156
 
137
157
  # Finally, add the new fused node to the graph
138
158
  graph.add_node(fused_node)
159
+
160
+
@@ -15,7 +15,8 @@
15
15
  from collections import namedtuple
16
16
 
17
17
  from copy import copy, deepcopy
18
- from typing import List, Tuple, Any
18
+ from functools import wraps
19
+ from typing import List, Tuple, Any, Callable
19
20
 
20
21
  import networkx as nx
21
22
  import numpy as np
@@ -23,6 +24,7 @@ import numpy as np
23
24
  from networkx.algorithms.dag import topological_sort
24
25
 
25
26
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
27
+ from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
26
28
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
27
29
  from model_compression_toolkit.core.common.graph.edge import Edge, convert_to_edge
28
30
  from model_compression_toolkit.core.common.graph.graph_searches import GraphSearches
@@ -36,6 +38,27 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
36
38
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
37
39
  FrameworkQuantizationCapabilities
38
40
 
41
+
42
+ def validate_graph_after_change(method: Callable) -> Callable:
43
+ """
44
+ Decorator for graph-mutating methods. After the decorated method executes,
45
+ this decorator calls `self.validate()` to ensure the graph remains in a valid state.
46
+
47
+ Args:
48
+ method: The graph-modifying method to wrap.
49
+
50
+ Returns:
51
+ A wrapped method that validates the graph after execution.
52
+ """
53
+ @wraps(method)
54
+ def wrapper(self, *args, **kwargs):
55
+ result = method(self, *args, **kwargs)
56
+ if not self.skip_validation_check:
57
+ self.validate() # calls Graph.validate(). Ensure graph consistency after changes.
58
+ return result
59
+ return wrapper
60
+
61
+
39
62
  OutTensor = namedtuple('OutTensor', 'node node_out_index')
40
63
 
41
64
 
@@ -63,6 +86,11 @@ class Graph(nx.MultiDiGraph, GraphSearches):
63
86
  """
64
87
 
65
88
  super().__init__(**attr)
89
+
90
+ # This must be set first to ensure it's available when validation runs during graph creation.
91
+ self._skip_validation_check = False
92
+ self._fusing_info = FusingInfo()
93
+
66
94
  self.name = name
67
95
  self.input_nodes = input_nodes
68
96
  self.output_nodes = output_nodes
@@ -75,7 +103,25 @@ class Graph(nx.MultiDiGraph, GraphSearches):
75
103
  **e.get_attributes())
76
104
  self.user_info = UserInformation()
77
105
  self.fw_info = fw_info
78
- self.fused_nodes = []
106
+
107
+ @property
108
+ def skip_validation_check(self) -> bool:
109
+ return self._skip_validation_check
110
+
111
+ @skip_validation_check.setter
112
+ def skip_validation_check(self, value: bool):
113
+ if not isinstance(value, bool):
114
+ raise ValueError("skip_validation_check must be a boolean.")
115
+ self._skip_validation_check = value
116
+
117
+ @property
118
+ def fusing_info(self) -> FusingInfo:
119
+ return self._fusing_info
120
+
121
+ @fusing_info.setter
122
+ @validate_graph_after_change
123
+ def fusing_info(self, fusing_info: FusingInfo):
124
+ self._fusing_info = fusing_info
79
125
 
80
126
  def set_fw_info(self,
81
127
  fw_info: FrameworkInfo):
@@ -139,6 +185,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
139
185
 
140
186
  return self.output_nodes
141
187
 
188
+ @validate_graph_after_change
142
189
  def set_inputs(self,
143
190
  input_nodes: List[BaseNode]):
144
191
  """
@@ -149,6 +196,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
149
196
 
150
197
  self.input_nodes = input_nodes
151
198
 
199
+ @validate_graph_after_change
152
200
  def set_outputs(self,
153
201
  output_nodes: List[OutTensor]):
154
202
  """
@@ -321,6 +369,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
321
369
  sort_attr = None
322
370
  return [edges_list.source_node for edges_list in self.incoming_edges(node_obj, sort_by_attr=sort_attr)]
323
371
 
372
+ @validate_graph_after_change
324
373
  def reconnect_out_edges(self,
325
374
  current_node: BaseNode,
326
375
  new_node: BaseNode):
@@ -337,6 +386,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
337
386
  self.add_edge(new_node, oe.sink_node, **oe.get_attributes())
338
387
  self.remove_edge(current_node, oe.sink_node)
339
388
 
389
+ @validate_graph_after_change
340
390
  def reconnect_in_edges(self,
341
391
  current_node: BaseNode,
342
392
  new_node: BaseNode):
@@ -353,6 +403,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
353
403
  self.add_edge(ie.source_node, new_node, **ie.get_attributes())
354
404
  self.remove_edge(ie.source_node, current_node)
355
405
 
406
+ @validate_graph_after_change
356
407
  def add_node_with_in_edges(self, new_node: BaseNode, input_nodes: List[BaseNode],
357
408
  input_nodes_output_index: List[int] = []):
358
409
  """
@@ -378,6 +429,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
378
429
  for sink_index, (in_node, source_index) in enumerate(zip(input_nodes, input_nodes_output_index)):
379
430
  self.add_edge(in_node, new_node, source_index=source_index, sink_index=sink_index)
380
431
 
432
+ @validate_graph_after_change
381
433
  def replace_output_node(self,
382
434
  current_node: BaseNode,
383
435
  new_node: BaseNode):
@@ -400,6 +452,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
400
452
  new_graph_outputs[graph_ot_index] = OutTensor(new_node, ot.node_out_index)
401
453
  self.set_outputs(new_graph_outputs)
402
454
 
455
+ @validate_graph_after_change
403
456
  def replace_input_node(self,
404
457
  current_node: BaseNode,
405
458
  new_node: BaseNode):
@@ -424,6 +477,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
424
477
  new_graph_inputs.append(new_node)
425
478
  self.set_inputs(new_graph_inputs)
426
479
 
480
+ @validate_graph_after_change
427
481
  def remove_node(self,
428
482
  node_to_remove: BaseNode,
429
483
  new_graph_inputs: List[BaseNode] = None,
@@ -713,16 +767,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
713
767
  node = prev_nodes[0]
714
768
  return node
715
769
 
716
- def update_fused_nodes(self, fusion: List[Any]):
717
- """
718
- Updates the graphs fusions list with a new list of nodes that have been fused.
719
-
720
- Args:
721
- fusion: A list of nodes that have been fused.
722
-
723
- """
724
- self.fused_nodes.append(fusion)
725
-
726
770
  def has_any_configurable_activation(self) -> bool:
727
771
  """
728
772
  Checks whether any node in the graph has a configurable activation quantization.
@@ -742,6 +786,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
742
786
 
743
787
  return any([n.has_any_configurable_weight() for n in self.nodes])
744
788
 
789
+ @validate_graph_after_change
745
790
  def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
746
791
  """
747
792
  Replaces a node in the graph with a new node.
@@ -867,4 +912,36 @@ class Graph(nx.MultiDiGraph, GraphSearches):
867
912
 
868
913
  return intermediate_nodes, next_node
869
914
 
915
+ def disable_fused_nodes_activation_quantization(self):
916
+ """
917
+ Disable activation quantization for all nodes in fused operations,
918
+ except for the last node in each fused group.
919
+ """
920
+ nodes_to_disable = [node for nodes in self.fusing_info.get_all_fused_operations().values() for node in nodes[:-1]]
921
+ for node in nodes_to_disable:
922
+ for qc in node.candidates_quantization_cfg:
923
+ qc.activation_quantization_cfg.enable_activation_quantization = False
924
+
925
+ def validate(self):
926
+ """
927
+ Validate that the current state of the graph is consistent with
928
+ the fusing information (e.g., no missing or incorrect fused node mapping).
870
929
 
930
+ Returns:
931
+ The result of the FusingInfo validation logic (typically None or raises error).
932
+ """
933
+ return self.fusing_info.validate(self)
934
+
935
+ @validate_graph_after_change
936
+ def add_edge(self, *args, **kwargs):
937
+ """
938
+ Wrap networkx functions (that modifies the graph) with our validate decorator.
939
+ """
940
+ return super().add_edge(*args, **kwargs)
941
+
942
+ @validate_graph_after_change
943
+ def remove_edge(self, *args, **kwargs):
944
+ """
945
+ Wrap networkx functions (that modifies the graph) with our validate decorator.
946
+ """
947
+ return super().remove_edge(*args, **kwargs)
@@ -65,6 +65,7 @@ def search_bit_width(graph: Graph,
65
65
  bit-width index on the node).
66
66
 
67
67
  """
68
+
68
69
  assert target_resource_utilization.is_any_restricted()
69
70
 
70
71
  # If we only run weights compression with MP than no need to consider activation quantization when computing the
@@ -88,6 +89,11 @@ def search_bit_width(graph: Graph,
88
89
  if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
89
90
  raise NotImplementedError()
90
91
 
92
+ # Validation is skipped during the mixed-precision search configuration because fusing information is not
93
+ # relevant for the virtual graph. Therefore, validation checks are disabled before the search begins and
94
+ # re-enabled once it completes.
95
+ graph.skip_validation_check = True
96
+
91
97
  # Search manager and LP are highly coupled, so LP search method was moved inside search manager.
92
98
  search_manager = MixedPrecisionSearchManager(graph,
93
99
  fw_info,
@@ -96,6 +102,8 @@ def search_bit_width(graph: Graph,
96
102
  target_resource_utilization)
97
103
  result_bit_cfg = search_manager.search()
98
104
 
105
+ graph.skip_validation_check = False
106
+
99
107
  if mp_config.refine_mp_solution:
100
108
  result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)
101
109
 
@@ -71,11 +71,13 @@ class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
71
71
 
72
72
  if weights_quantization_cfg is not None:
73
73
  self.weights_quantization_cfg = weights_quantization_cfg
74
- else:
75
- if any(v is None for v in (qc, op_cfg, node_attrs_list)): # pragma: no cover
76
- Logger.critical("Missing required arguments to initialize a node weights quantization configuration. "
77
- "Ensure QuantizationConfig, OpQuantizationConfig, weights quantization function, "
78
- "parameters function, and weights attribute quantization config are provided.")
79
- self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc, op_cfg=op_cfg,
74
+ elif all(v is not None for v in (qc, op_cfg, node_attrs_list)):
75
+ self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc,
76
+ op_cfg=op_cfg,
80
77
  weights_channels_axis=weights_channels_axis,
81
78
  node_attrs_list=node_attrs_list)
79
+ else:
80
+ self.weights_quantization_cfg = None
81
+ Logger.debug("Setting weights quantization config as None during CandidateNodeQuantizationConfig creation."
82
+ "Notice, this should happen only for FLN nodes.")
83
+
@@ -19,11 +19,11 @@ from typing import Callable
19
19
 
20
20
  import numpy as np
21
21
 
22
+ from model_compression_toolkit.core.common import Graph
22
23
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
23
24
  from model_compression_toolkit.core import common
24
25
  from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
25
26
  from model_compression_toolkit.logger import Logger
26
- from model_compression_toolkit.core.common.graph.base_graph import Graph
27
27
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
28
28
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
29
29
  from mct_quantizers import QuantizationMethod
@@ -143,6 +143,21 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
143
143
  AttributeQuantizationConfig(
144
144
  enable_weights_quantization=False)))
145
145
 
146
+ # Check if the source node was part of a fusion. If so, there are two cases:
147
+ # either this is no longer a fusion, and the fusion info should be updated by removing
148
+ # the current info, or this creates a new fusion and the old pattern should be
149
+ # replaced with the new one.
150
+ fi = graph.fusing_info
151
+ fused_op = fi.get_fused_node_name(source_node.name)
152
+ if fused_op:
153
+ fused_nodes = list(fi.get_fused_nodes(fused_op))
154
+ assert source_node in fused_nodes
155
+ fused_nodes.insert(fused_nodes.index(source_node)+1, bn_node)
156
+ fi.remove_fused_operation(fused_op)
157
+ if fi.is_nodes_eligible_to_be_fused(fused_nodes):
158
+ op_id = fi.generate_fused_op_id(fused_nodes)
159
+ fi.add_fused_operation(op_id, tuple(fused_nodes))
160
+
146
161
  graph.reconnect_out_edges(current_node=source_node, new_node=bn_node)
147
162
  graph.replace_output_node(current_node=source_node, new_node=bn_node)
148
163
  graph.add_node_with_in_edges(bn_node, [source_node])
@@ -18,7 +18,7 @@ from typing import Callable, Any
18
18
 
19
19
  from model_compression_toolkit.core.common import FrameworkInfo
20
20
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21
- from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
21
+ from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
22
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
23
23
  from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
24
24
  from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
@@ -136,6 +136,7 @@ def get_finalized_graph(initial_graph: Graph,
136
136
  node.prior_info = fw_impl.get_node_prior_info(node=node,
137
137
  fw_info=fw_info,
138
138
  graph=graph)
139
+
139
140
  ##################################################
140
141
  # Graph substitution (pre statistics collection)
141
142
  ##################################################
@@ -161,7 +162,9 @@ def get_finalized_graph(initial_graph: Graph,
161
162
  ######################################
162
163
  # Layer fusing
163
164
  ######################################
164
- transformed_graph = fusion(transformed_graph, fqc)
165
+ fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(transformed_graph)
166
+ transformed_graph.fusing_info = fusing_info
167
+ transformed_graph.disable_fused_nodes_activation_quantization()
165
168
 
166
169
  ######################################
167
170
  # Channel equalization
@@ -184,15 +184,14 @@ def core_runner(in_model: Any,
184
184
 
185
185
  scheduler_info = None
186
186
  if core_config.debug_config.simulate_scheduler:
187
- graph_to_fuse = copy.deepcopy(tg)
188
- fused_nodes_mapping = GraphFuser().create_fused_graph(graph_to_fuse)
189
- memory_graph = MemoryGraph(graph_to_fuse)
187
+ fused_graph = GraphFuser().apply_node_fusion(tg)
188
+ memory_graph = MemoryGraph(fused_graph)
190
189
  schedule, max_cut, cuts = compute_graph_max_cut(memory_graph)
191
190
  scheduler_info = SchedulerInfo(
192
191
  operators_scheduling=schedule,
193
192
  max_cut=float(max_cut),
194
193
  cuts=cuts,
195
- fused_nodes_mapping=fused_nodes_mapping
194
+ fused_nodes_mapping=tg.fusing_info.get_node_to_fused_node_map()
196
195
  )
197
196
 
198
197
  return tg, bit_widths_config, hessian_info_service, scheduler_info
@@ -1,4 +1,4 @@
1
- import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
1
+ import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema
2
2
 
3
3
  OperatorSetNames = schema.OperatorSetNames
4
4
  Signedness = schema.Signedness
@@ -0,0 +1,177 @@
1
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import pprint
16
+ from enum import Enum
17
+ from typing import Dict, Any, Tuple, Optional
18
+
19
+ from pydantic import BaseModel, root_validator
20
+
21
+ from mct_quantizers import QuantizationMethod
22
+ from model_compression_toolkit.constants import FLOAT_BITWIDTH
23
+ from model_compression_toolkit.logger import Logger
24
+ from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
25
+ Signedness,
26
+ AttributeQuantizationConfig,
27
+ OpQuantizationConfig,
28
+ QuantizationConfigOptions,
29
+ TargetPlatformModelComponent,
30
+ OperatorsSetBase,
31
+ OperatorsSet,
32
+ OperatorSetGroup,
33
+ Fusing)
34
+
35
+
36
+ class OperatorSetNames(str, Enum):
37
+ CONV = "Conv"
38
+ DEPTHWISE_CONV = "DepthwiseConv2D"
39
+ CONV_TRANSPOSE = "ConvTranspose"
40
+ FULLY_CONNECTED = "FullyConnected"
41
+ CONCATENATE = "Concatenate"
42
+ STACK = "Stack"
43
+ UNSTACK = "Unstack"
44
+ GATHER = "Gather"
45
+ EXPAND = "Expend"
46
+ BATCH_NORM = "BatchNorm"
47
+ L2NORM = "L2Norm"
48
+ RELU = "ReLU"
49
+ RELU6 = "ReLU6"
50
+ LEAKY_RELU = "LeakyReLU"
51
+ ELU = "Elu"
52
+ HARD_TANH = "HardTanh"
53
+ ADD = "Add"
54
+ SUB = "Sub"
55
+ MUL = "Mul"
56
+ DIV = "Div"
57
+ MIN = "Min"
58
+ MAX = "Max"
59
+ PRELU = "PReLU"
60
+ ADD_BIAS = "AddBias"
61
+ SWISH = "Swish"
62
+ SIGMOID = "Sigmoid"
63
+ SOFTMAX = "Softmax"
64
+ LOG_SOFTMAX = "LogSoftmax"
65
+ TANH = "Tanh"
66
+ GELU = "Gelu"
67
+ HARDSIGMOID = "HardSigmoid"
68
+ HARDSWISH = "HardSwish"
69
+ FLATTEN = "Flatten"
70
+ GET_ITEM = "GetItem"
71
+ RESHAPE = "Reshape"
72
+ UNSQUEEZE = "Unsqueeze"
73
+ SQUEEZE = "Squeeze"
74
+ PERMUTE = "Permute"
75
+ TRANSPOSE = "Transpose"
76
+ DROPOUT = "Dropout"
77
+ SPLIT_CHUNK = "SplitChunk"
78
+ MAXPOOL = "MaxPool"
79
+ AVGPOOL = "AvgPool"
80
+ SIZE = "Size"
81
+ SHAPE = "Shape"
82
+ EQUAL = "Equal"
83
+ ARGMAX = "ArgMax"
84
+ TOPK = "TopK"
85
+ FAKE_QUANT = "FakeQuant"
86
+ COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
87
+ BOX_DECODE = "BoxDecode"
88
+ ZERO_PADDING2D = "ZeroPadding2D"
89
+ CAST = "Cast"
90
+ RESIZE = "Resize"
91
+ PAD = "Pad"
92
+ FOLD = "Fold"
93
+ STRIDED_SLICE = "StridedSlice"
94
+ SSD_POST_PROCESS = "SSDPostProcess"
95
+
96
+ @classmethod
97
+ def get_values(cls):
98
+ return [v.value for v in cls]
99
+
100
+
101
+ class TargetPlatformCapabilities(BaseModel):
102
+ """
103
+ Represents the hardware configuration used for quantized model inference.
104
+
105
+ Attributes:
106
+ default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
107
+ operator_set (Optional[Tuple[OperatorsSet, ...]]): Tuple of operator sets within the model.
108
+ fusing_patterns (Optional[Tuple[Fusing, ...]]): Tuple of fusing patterns for the model.
109
+ tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
110
+ tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
111
+ tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
112
+ add_metadata (bool): Flag to determine if metadata should be added.
113
+ name (str): Name of the Target Platform Model.
114
+ is_simd_padding (bool): Indicates if SIMD padding is applied.
115
+ SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
116
+ """
117
+ default_qco: QuantizationConfigOptions
118
+ operator_set: Optional[Tuple[OperatorsSet, ...]]
119
+ fusing_patterns: Optional[Tuple[Fusing, ...]]
120
+ tpc_minor_version: Optional[int]
121
+ tpc_patch_version: Optional[int]
122
+ tpc_platform_type: Optional[str]
123
+ add_metadata: bool = True
124
+ name: Optional[str] = "default_tpc"
125
+ is_simd_padding: bool = False
126
+
127
+ SCHEMA_VERSION: int = 2
128
+
129
+ class Config:
130
+ frozen = True
131
+
132
+ @root_validator(allow_reuse=True)
133
+ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
134
+ """
135
+ Perform validation after the model has been instantiated.
136
+
137
+ Args:
138
+ values (Dict[str, Any]): The instantiated target platform model.
139
+
140
+ Returns:
141
+ Dict[str, Any]: The validated values.
142
+ """
143
+ # Validate `default_qco`
144
+ default_qco = values.get('default_qco')
145
+ if len(default_qco.quantization_configurations) != 1:
146
+ Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
147
+
148
+ # Validate `operator_set` uniqueness
149
+ operator_set = values.get('operator_set')
150
+ if operator_set is not None:
151
+ opsets_names = [
152
+ op.name.value if isinstance(op.name, OperatorSetNames) else op.name
153
+ for op in operator_set
154
+ ]
155
+ if len(set(opsets_names)) != len(opsets_names):
156
+ Logger.critical("Operator Sets must have unique names.") # pragma: no cover
157
+
158
+ return values
159
+
160
+ def get_info(self) -> Dict[str, Any]:
161
+ """
162
+ Get a dictionary summarizing the TargetPlatformCapabilities properties.
163
+
164
+ Returns:
165
+ Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
166
+ """
167
+ return {
168
+ "Model name": self.name,
169
+ "Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
170
+ "Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
171
+ }
172
+
173
+ def show(self):
174
+ """
175
+ Display the TargetPlatformCapabilities.
176
+ """
177
+ pprint.pprint(self.get_info(), sort_dicts=False)
@@ -93,6 +93,7 @@ class AttachTpcToKeras(AttachTpcToFramework):
93
93
  OperatorSetNames.TOPK: [tf.nn.top_k],
94
94
  OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
95
95
  OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
96
+ OperatorSetNames.BOX_DECODE: [], # no such operator in keras
96
97
  OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
97
98
  OperatorSetNames.CAST: [tf.cast],
98
99
  OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
@@ -97,7 +97,8 @@ class AttachTpcToPytorch(AttachTpcToFramework):
97
97
  OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
98
98
  Eq('p', 2) | Eq('p', None))],
99
99
  OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
100
- OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [] # no such operator in pytorch
100
+ OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [], # no such operator in pytorch
101
+ OperatorSetNames.BOX_DECODE: [] # no such operator in pytorch
101
102
  }
102
103
 
103
104
  pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
@@ -1,131 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import copy
16
- from typing import Any, List
17
- from model_compression_toolkit.core.common.graph.base_graph import Graph
18
- from model_compression_toolkit.core.common.graph.base_node import BaseNode
19
- from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
20
- FrameworkQuantizationCapabilities
21
- from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import LayerFilterParams
22
-
23
-
24
- def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx: int = 0) -> List[List[Any]]:
25
- """
26
- Update relevant fusing patterns object if layer number 'idx' inside the fusion matches the node
27
- Args:
28
- fusing_patterns: supported fusings
29
- node: node to decide if it can be a part of fusion
30
- idx: index of layer in the fusion
31
- Returns:
32
- fusing_patterns after filtering non-relevant fusions
33
- """
34
- valid_fusing_patterns = []
35
- for i, fusing_pattern in enumerate(fusing_patterns):
36
- if idx < len(fusing_pattern):
37
- if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or \
38
- node.is_match_type(fusing_pattern[idx]):
39
- valid_fusing_patterns.append(fusing_pattern)
40
-
41
- # Return only valid patterns for this node
42
- return valid_fusing_patterns
43
-
44
-
45
- def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) -> bool:
46
- """
47
- Check if the fusion is valid: exist in fusing_patterns
48
- Args:
49
- fusing_patterns: supported fusing patterns
50
- nodes: nodes which are participating in fusion
51
- Returns:
52
- whether the fusion in valid
53
- """
54
- fusion_depth = len(nodes)
55
- if fusion_depth <= 1:
56
- return False
57
- for fusing_pattern in fusing_patterns:
58
- if fusion_depth != len(fusing_pattern):
59
- continue
60
- counter = 0
61
- for i, layer in enumerate(fusing_pattern):
62
- if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
63
- nodes[i].is_match_type(layer):
64
- counter += 1
65
- if counter == fusion_depth:
66
- return True
67
- return False
68
-
69
-
70
- def disable_nodes_activation_quantization(nodes: List[BaseNode]):
71
- """
72
- Disable activation for non-quantization needed due to fusion
73
- Args:
74
- nodes: nodes to update their activation quantization
75
- """
76
- for node in nodes:
77
- for qc in node.candidates_quantization_cfg:
78
- qc.activation_quantization_cfg.enable_activation_quantization = False
79
-
80
-
81
- def fusion(graph: Graph, fqc: FrameworkQuantizationCapabilities) -> Graph:
82
- """
83
- Fusing defines a list of operators that should be combined and treated as a single operator,
84
- hence no quantization is applied between them when they appear in the graph.
85
- This function search and disable quantization for such patterns.
86
- Args:
87
- graph: Graph we apply the fusion on.
88
- fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
89
- Returns:
90
- Graph after applying fusion activation marking.
91
- """
92
- fusing_patterns = fqc.get_fusing_patterns()
93
- if len(fusing_patterns) == 0:
94
- return graph
95
-
96
- # Find max fusion
97
- max_layers_fusing = 0 if len(fusing_patterns) == 0 else max([len(fusing_pattern) for fusing_pattern in fusing_patterns])
98
-
99
-
100
- # -------------------------------- #
101
- # Fusion algorithm
102
- # -------------------------------- #
103
- fused_graph = copy.deepcopy(graph)
104
-
105
- # Travel along the graph to find layers for fusing
106
- nodes = fused_graph.get_topo_sorted_nodes()
107
- fused_nodes = [] # nodes that are participating in fusing
108
- for node in nodes:
109
- # Skip if already in fusing
110
- if node in fused_nodes:
111
- continue
112
- # Start fusing search
113
- fusing_nodes = [] # nodes that are candidates for participating in fusing
114
- patterns = copy.deepcopy(fusing_patterns)
115
- next_nodes = [node]
116
- for i in range(max_layers_fusing):
117
- patterns = filter_fusing_patterns(patterns, next_nodes[0], i)
118
- if len(patterns) == 0: # Give up if no more fusion pattern
119
- break
120
- fusing_nodes.append(next_nodes[0])
121
- next_nodes = fused_graph.get_next_nodes(fusing_nodes[-1])
122
- if len(next_nodes) != 1: # Give up if node has more than one connection (not supported for fusion)
123
- break
124
-
125
- # New fusion: mark all nodes in the fusion except last one
126
- if is_valid_fusion(fusing_patterns, fusing_nodes):
127
- fused_nodes.extend(fusing_nodes)
128
- disable_nodes_activation_quantization(fusing_nodes[:-1])
129
- fused_graph.update_fused_nodes(fusing_nodes)
130
-
131
- return fused_graph