mct-nightly 2.3.0.20250514.602__py3-none-any.whl → 2.3.0.20250516.613__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/METADATA +2 -2
- {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/RECORD +17 -17
- {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +5 -0
- model_compression_toolkit/core/common/fusion/fusing_info.py +75 -7
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +8 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +11 -8
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py +54 -14
- model_compression_toolkit/target_platform_capabilities/schema/v2.py +7 -5
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py +4 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -0
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py +25 -6
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +17 -2
- {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.
|
3
|
+
Version: 2.3.0.20250516.613
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Author-email: ssi-dnn-dev@sony.com
|
6
6
|
Classifier: Programming Language :: Python :: 3
|
@@ -21,7 +21,7 @@ Requires-Dist: PuLP
|
|
21
21
|
Requires-Dist: matplotlib<3.10.0
|
22
22
|
Requires-Dist: scipy
|
23
23
|
Requires-Dist: protobuf
|
24
|
-
Requires-Dist: mct-quantizers
|
24
|
+
Requires-Dist: mct-quantizers==1.6.0
|
25
25
|
Requires-Dist: pydantic>=2.0
|
26
26
|
Requires-Dist: edge-mdt-cl-dev
|
27
27
|
Dynamic: author-email
|
{mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
3
|
-
model_compression_toolkit/constants.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250516.613.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=vpUrhwvqdXrPwyv56B5qlxS71UYcqZVGpzP-u2yJA9E,1557
|
3
|
+
model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
6
6
|
model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
|
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
|
|
31
31
|
model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
|
32
32
|
model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
|
33
33
|
model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
34
|
-
model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=
|
34
|
+
model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=S7hBbUJxL52Z8uJ9_upLdFyoSEJvgmVX0OmneqDIj-c,18656
|
35
35
|
model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
|
36
36
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
37
37
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=BSQpKy0BXoGX0G0bySTo72n2isTqvtpkbRYYa8-hPO4,41435
|
@@ -233,7 +233,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
|
|
233
233
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
234
234
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
235
235
|
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=HAzzWOnPcIeDxQO1712254RNTBZD-gVSMSVnxqpfuQ0,11907
|
236
|
-
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=
|
236
|
+
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=FVewVclo3kx-Oufr_PJE4-MAqkKJseBvd96vz8JtuBg,22163
|
237
237
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
238
238
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
239
239
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
|
@@ -349,7 +349,7 @@ model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer
|
|
349
349
|
model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
350
350
|
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=vQUGbCi8_pGoN8DwQ0IblSeN6L9t6Cr0reZNuCbBpkM,3469
|
351
351
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
352
|
-
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=
|
352
|
+
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=AKSpWbTtXHPjW7hY655OXANaK5SgEiF-FZCu5zoioxM,6860
|
353
353
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=Pl8a8MSZMzNbm5vngujFjCt_iSMbSmKjlcL1DvN9nTM,9292
|
354
354
|
model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
|
355
355
|
model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
|
@@ -433,20 +433,20 @@ model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha2
|
|
433
433
|
model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=8RVOriZg-XNjSt53h_4Yum0oRgOe2gp5H45dfG_lZxE,1415
|
434
434
|
model_compression_toolkit/target_platform_capabilities/constants.py,sha256=JRz9DoxLRpkqvu532TFkIvv0595Bfb9NtU4pRp4urDY,1540
|
435
435
|
model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
|
436
|
-
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=
|
436
|
+
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=hFBq-qKUM9qKZGaMmrxsEmurTV_D1kWIXI1rTERZsbk,5241
|
437
437
|
model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
|
438
|
-
model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=
|
439
|
-
model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=
|
438
|
+
model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=hf539WJ3nBGn0RnALXrKmAPnbhJ-VmWmLIa207x8b4M,541
|
439
|
+
model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=OpZ9SH2aTAVTCBfj1m3wcAeouk_q_16yWxCwByXK_M8,6294
|
440
440
|
model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
|
441
441
|
model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=oWKNQnnz04kmijmdWtRyXgVXbJ6BG_V_bUBz_MfUM94,27116
|
442
|
-
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=
|
442
|
+
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=1hrvq4EeLDRe0-wvpHkMLXMYYbETQ_tX-3FAHHsxb18,10880
|
443
443
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
|
444
444
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
|
445
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=
|
446
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=
|
445
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=5Uyb5CurpLm4fgOiARKYwy3T-bb0NMmJXIRBgRjMgjo,7301
|
446
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=R-kTbJka37u3toun9rRDGGGXYR3Sv4VdirLIn5G1BgQ,6541
|
447
447
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
|
448
448
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
|
449
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=
|
449
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=Y-HZKwoakzY6PAYYj9l-h19yLMqBs0qBHo2YIKIsrN8,10375
|
450
450
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities_component.py,sha256=9Hg6AMCzTdDsKKgivRd61UjxGT5SWvKsc3mIUPPsYDQ,1021
|
451
451
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/layer_filter_params.py,sha256=dIu6k1xvGKLtk_47wq1eKYvrS4lYAknAXTeJfFstW0Y,3878
|
452
452
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/operations_to_layers.py,sha256=vZ7I2XDr_YDgU8oQt8gKkcuUOJf28DCzCPunPK2h_Xw,6563
|
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
528
528
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
529
529
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
530
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
534
|
-
mct_nightly-2.3.0.
|
531
|
+
mct_nightly-2.3.0.20250516.613.dist-info/METADATA,sha256=WTFK8E9-__LO8PW9LL44DHCYKwUvNgKMkGl0ko8dcf0,25135
|
532
|
+
mct_nightly-2.3.0.20250516.613.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
533
|
+
mct_nightly-2.3.0.20250516.613.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
534
|
+
mct_nightly-2.3.0.20250516.613.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250516.000613"
|
@@ -138,3 +138,8 @@ SHAPE = 'shape'
|
|
138
138
|
NODE_NAME = 'node_name'
|
139
139
|
TOTAL_SIZE = 'total_size'
|
140
140
|
NODE_OUTPUT_INDEX = 'node_output_index'
|
141
|
+
|
142
|
+
|
143
|
+
# Fusing Patterns constants
|
144
|
+
FUSED_LAYER_PATTERN = 'fused_layer_pattern'
|
145
|
+
FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config'
|
@@ -14,6 +14,8 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig
|
18
|
+
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
|
17
19
|
from dataclasses import dataclass, field
|
18
20
|
|
19
21
|
from typing import Optional, List, Dict, Any, Tuple
|
@@ -41,6 +43,7 @@ class FusingInfo:
|
|
41
43
|
fusing_patterns: any = None
|
42
44
|
fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
|
43
45
|
node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
|
46
|
+
fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)
|
44
47
|
|
45
48
|
def __post_init__(self):
|
46
49
|
"""Validates and initializes mappings after dataclass instantiation."""
|
@@ -49,6 +52,7 @@ class FusingInfo:
|
|
49
52
|
assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"
|
50
53
|
|
51
54
|
self._init_node_mapping()
|
55
|
+
self._init_quantization_config_map()
|
52
56
|
|
53
57
|
def _init_node_mapping(self) -> None:
|
54
58
|
"""
|
@@ -59,6 +63,15 @@ class FusingInfo:
|
|
59
63
|
for node in nodes:
|
60
64
|
self.node_to_fused_node_map[node.name] = op_id
|
61
65
|
|
66
|
+
def _init_quantization_config_map(self) -> None:
|
67
|
+
"""
|
68
|
+
Init the mapping between fused operation IDs and their quantization configurations.
|
69
|
+
"""
|
70
|
+
self.fused_op_id_to_quant_config.clear()
|
71
|
+
if self.fusing_patterns is not None:
|
72
|
+
for op_id, nodes in self.fusing_data.items():
|
73
|
+
self.set_fused_op_quantization_config(op_id, nodes)
|
74
|
+
|
62
75
|
def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
|
63
76
|
"""
|
64
77
|
Add a new fused operation with the given ID and set of nodes.
|
@@ -78,6 +91,22 @@ class FusingInfo:
|
|
78
91
|
for node in nodes:
|
79
92
|
self.node_to_fused_node_map[node.name] = op_id
|
80
93
|
|
94
|
+
# Update the quantization config mapping for this operation
|
95
|
+
if self.fusing_patterns is not None:
|
96
|
+
self.set_fused_op_quantization_config(op_id, nodes)
|
97
|
+
|
98
|
+
def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
|
99
|
+
"""
|
100
|
+
Set the quantization configuration for a given fused operation ID.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
op_id (str): The identifier for the fused operation.
|
104
|
+
nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
|
105
|
+
"""
|
106
|
+
fusing_pattern = next((fp for fp in self.fusing_patterns if is_valid_fusion([fp.get(FUSED_LAYER_PATTERN)], nodes)), None)
|
107
|
+
if fusing_pattern is not None:
|
108
|
+
self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
|
109
|
+
|
81
110
|
def remove_fused_operation(self, op_id: str) -> None:
|
82
111
|
"""
|
83
112
|
Remove a fused operation by its ID.
|
@@ -95,6 +124,7 @@ class FusingInfo:
|
|
95
124
|
for node in nodes:
|
96
125
|
self.node_to_fused_node_map.pop(node.name, None)
|
97
126
|
del self.fusing_data[op_id]
|
127
|
+
self.fused_op_id_to_quant_config.pop(op_id, None)
|
98
128
|
|
99
129
|
def get_fused_node_name(self, node_name: str) -> Optional[str]:
|
100
130
|
"""
|
@@ -117,6 +147,15 @@ class FusingInfo:
|
|
117
147
|
"""
|
118
148
|
return self.node_to_fused_node_map.copy()
|
119
149
|
|
150
|
+
def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
|
151
|
+
"""
|
152
|
+
Retrieve a copy of the mapping from fused operation IDs to their quantization configurations.
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
A dictionary mapping each fused operation ID to its quantization configuration.
|
156
|
+
"""
|
157
|
+
return self.fused_op_id_to_quant_config.copy()
|
158
|
+
|
120
159
|
def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
|
121
160
|
"""
|
122
161
|
Retrieve the list of nodes for a given fused operation ID.
|
@@ -129,6 +168,18 @@ class FusingInfo:
|
|
129
168
|
"""
|
130
169
|
return self.fusing_data.get(op_id)
|
131
170
|
|
171
|
+
def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig:
|
172
|
+
"""
|
173
|
+
Retrieve the quantization configuration for a given fused operation ID.
|
174
|
+
|
175
|
+
Args:
|
176
|
+
op_id (str): The identifier for the fused operation.
|
177
|
+
|
178
|
+
Returns:
|
179
|
+
OpQuantizationConfig: The quantization configuration for the operation, or None if not found.
|
180
|
+
"""
|
181
|
+
return self.fused_op_id_to_quant_config.get(op_id)
|
182
|
+
|
132
183
|
def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
|
133
184
|
"""
|
134
185
|
Check if a node is part of any fused operation.
|
@@ -216,10 +267,11 @@ class FusingInfo:
|
|
216
267
|
all_fused_nodes.update(node_set)
|
217
268
|
|
218
269
|
# Check 4: Ensure the sequence matches a valid fusing pattern
|
219
|
-
|
270
|
+
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
|
271
|
+
if not is_valid_fusion(valid_fusing_patterns, nodes):
|
220
272
|
raise ValueError(
|
221
273
|
f"Fused operation {op_id} does not match any valid fusing pattern "
|
222
|
-
f"from {
|
274
|
+
f"from {valid_fusing_patterns}."
|
223
275
|
)
|
224
276
|
|
225
277
|
def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
|
@@ -240,7 +292,8 @@ class FusingInfo:
|
|
240
292
|
return False
|
241
293
|
|
242
294
|
# Check if the provided nodes match a valid fusion pattern
|
243
|
-
|
295
|
+
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
|
296
|
+
return is_valid_fusion(fusing_patterns=valid_fusing_patterns, nodes=nodes)
|
244
297
|
|
245
298
|
def __repr__(self) -> str:
|
246
299
|
"""
|
@@ -287,8 +340,11 @@ class FusingInfoGenerator:
|
|
287
340
|
if not self._fusing_patterns:
|
288
341
|
return FusingInfo(fusing_patterns=self._fusing_patterns)
|
289
342
|
|
343
|
+
# Extract fusing layer patterns
|
344
|
+
fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)
|
345
|
+
|
290
346
|
# Find max fusion
|
291
|
-
|
347
|
+
max_layer_patterns = max([len(fusing_layer_pattern) for fusing_layer_pattern in fusing_layer_patterns])
|
292
348
|
|
293
349
|
# Travel along the graph to find layers for fusing
|
294
350
|
nodes = graph.get_topo_sorted_nodes()
|
@@ -302,9 +358,9 @@ class FusingInfoGenerator:
|
|
302
358
|
continue
|
303
359
|
# Start fusing search
|
304
360
|
fusing_nodes = [] # nodes that are candidates for participating in fusing
|
305
|
-
patterns = copy.deepcopy(
|
361
|
+
patterns = copy.deepcopy(fusing_layer_patterns)
|
306
362
|
next_nodes = [node]
|
307
|
-
for i in range(
|
363
|
+
for i in range(max_layer_patterns):
|
308
364
|
patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
|
309
365
|
if len(patterns) == 0: # Give up if no more fusion pattern
|
310
366
|
break
|
@@ -314,7 +370,7 @@ class FusingInfoGenerator:
|
|
314
370
|
break
|
315
371
|
|
316
372
|
# New fusion
|
317
|
-
if is_valid_fusion(
|
373
|
+
if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
|
318
374
|
fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
|
319
375
|
assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
|
320
376
|
fusing_info[fused_op_id] = tuple(fusing_nodes)
|
@@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
|
|
371
427
|
if counter == fusion_depth:
|
372
428
|
return True
|
373
429
|
return False
|
430
|
+
|
431
|
+
|
432
|
+
def _get_fusing_layer_patterns(fusing_patterns: List[Dict[Any, OpQuantizationConfig]]) -> List[List[Any]]:
|
433
|
+
"""
|
434
|
+
Extracts the fusing layer patterns from the provided fusing patterns.
|
435
|
+
Args:
|
436
|
+
fusing_patterns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.
|
437
|
+
|
438
|
+
Returns:
|
439
|
+
supported fusing layer patterns
|
440
|
+
"""
|
441
|
+
return [f.get(FUSED_LAYER_PATTERN) for f in fusing_patterns]
|
@@ -233,6 +233,7 @@ class PytorchModel(torch.nn.Module):
|
|
233
233
|
self.return_float_outputs = return_float_outputs
|
234
234
|
self.wrapper = wrapper
|
235
235
|
self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
|
236
|
+
self.insert_preserving_quantizers = graph.fqc.insert_preserving_quantizers
|
236
237
|
self.reuse_groups = {}
|
237
238
|
self._reused_nodes = []
|
238
239
|
|
@@ -335,12 +336,17 @@ class PytorchModel(torch.nn.Module):
|
|
335
336
|
activation_quantizer_holder = None
|
336
337
|
if self.use_activation_holder_during_model_building:
|
337
338
|
if node.is_activation_quantization_enabled():
|
338
|
-
activation_quantizer_holder = self.get_activation_quantizer_holder(node,
|
339
|
+
activation_quantizer_holder = self.get_activation_quantizer_holder(node,
|
340
|
+
holder_type=PytorchActivationQuantizationHolder)
|
339
341
|
|
340
342
|
elif node.is_quantization_preserving():
|
341
343
|
prev_node = self.graph.retrieve_preserved_quantization_node(node)
|
342
344
|
if prev_node.is_activation_quantization_enabled():
|
343
|
-
|
345
|
+
if self.insert_preserving_quantizers:
|
346
|
+
holder_kwargs = {'quantization_bypass': True}
|
347
|
+
activation_quantizer_holder = self.get_activation_quantizer_holder(prev_node,
|
348
|
+
holder_type=PytorchPreservingActivationQuantizationHolder,
|
349
|
+
**holder_kwargs)
|
344
350
|
|
345
351
|
if activation_quantizer_holder is not None:
|
346
352
|
activation_quantizer_holder_name = node.name + '_' + ACTIVATION_HOLDER_QUANTIZER
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED
@@ -65,26 +65,29 @@ if FOUND_TORCH:
|
|
65
65
|
return module
|
66
66
|
|
67
67
|
|
68
|
-
def get_activation_quantizer_holder(node: BaseNode, holder_type: PytorchActivationQuantizationHolder,
|
68
|
+
def get_activation_quantizer_holder(node: BaseNode, holder_type: PytorchActivationQuantizationHolder,
|
69
|
+
fw_impl, **kwargs) -> Callable:
|
69
70
|
"""
|
70
71
|
Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
|
71
72
|
If the layer is not supposed to be wrapped with an activation quantizer - return None.
|
73
|
+
|
72
74
|
Args:
|
73
75
|
node: Node to attach a PytorchActivationQuantizationHolder to its output.
|
74
76
|
holder_type: The type of the activation quantization holder to use.
|
75
77
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
78
|
+
**kwargs: Key-arguments to be passed to the quantization holder initialization to set specific arguments
|
79
|
+
based on the holder's type.
|
80
|
+
|
76
81
|
Returns:
|
77
82
|
A PytorchActivationQuantizationHolder module for the node's activation quantization.
|
78
83
|
"""
|
79
84
|
# Holder by definition uses a single quantizer for the activation quantization
|
80
|
-
# thus we make sure this is the only possible case (unless it's a node
|
85
|
+
# thus we make sure this is the only possible case (unless it's a node with no activation
|
81
86
|
# quantization, which in this case has an empty list).
|
82
87
|
_, activation_quantizers = fw_impl.get_inferable_quantizers(node)
|
83
88
|
if len(activation_quantizers) == 1:
|
84
|
-
|
85
|
-
|
86
|
-
elif holder_type == PytorchPreservingActivationQuantizationHolder:
|
87
|
-
return holder_type(activation_quantizers[0], quantization_bypass=True)
|
89
|
+
return holder_type(activation_quantizers[0], **kwargs)
|
90
|
+
|
88
91
|
Logger.critical(
|
89
92
|
f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
|
90
93
|
f'were found for node {node}')
|
@@ -105,9 +108,9 @@ if FOUND_TORCH:
|
|
105
108
|
wrapper=lambda n, m:
|
106
109
|
fully_quantized_wrapper(n, m,
|
107
110
|
fw_impl=fw_impl),
|
108
|
-
get_activation_quantizer_holder_fn=lambda n, holder_type:
|
111
|
+
get_activation_quantizer_holder_fn=lambda n, holder_type, **kwargs:
|
109
112
|
get_activation_quantizer_holder(n, holder_type,
|
110
|
-
fw_impl=fw_impl)).build_model()
|
113
|
+
fw_impl=fw_impl, **kwargs)).build_model()
|
111
114
|
|
112
115
|
Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
|
113
116
|
"Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
|
@@ -12,14 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
import copy
|
15
16
|
from typing import Any, Union
|
16
17
|
|
17
18
|
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema_v1
|
18
19
|
import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema_v2
|
19
20
|
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as current_schema
|
20
21
|
|
21
|
-
ALL_SCHEMA_VERSIONS = [schema_v1] # needs to be updated with all active schema versions
|
22
|
-
FUTURE_SCHEMA_VERSIONS = [
|
22
|
+
ALL_SCHEMA_VERSIONS = [schema_v1, schema_v2] # needs to be updated with all active schema versions
|
23
|
+
FUTURE_SCHEMA_VERSIONS = [] # once future schema becomes current schema, move to it ALL_SCHEMA_VERSIONS
|
23
24
|
all_tpc_types = tuple([s.TargetPlatformCapabilities for s in ALL_SCHEMA_VERSIONS])
|
24
25
|
tpc_or_str_type = all_tpc_types + (str,)
|
25
26
|
|
@@ -33,19 +34,57 @@ def is_tpc_instance(tpc_obj_or_path: Any) -> bool:
|
|
33
34
|
return type(tpc_obj_or_path) in all_tpc_types
|
34
35
|
|
35
36
|
|
36
|
-
def
|
37
|
+
def get_schema_by_version(schema_version: str):
|
38
|
+
return {
|
39
|
+
"1": schema_v1,
|
40
|
+
"2": schema_v2
|
41
|
+
}[schema_version]
|
42
|
+
|
43
|
+
|
44
|
+
def _schema_v1_to_v2(
|
45
|
+
tpc: schema_v1.TargetPlatformCapabilities) -> schema_v2.TargetPlatformCapabilities:
|
37
46
|
"""
|
38
|
-
Converts given tpc of schema version 1 to schema version 2
|
39
|
-
|
47
|
+
Converts given tpc of schema version 1 to schema version 2.
|
48
|
+
Schema v2 updates:
|
49
|
+
1. New TPC field: insert_preserving_quantizers
|
50
|
+
Compatability behavior: Set field to False by default
|
51
|
+
2. New field in Fusing: fuse_op_quantization_config
|
52
|
+
Compatability behavior: set enable_activation_quantization=False in every fuse_op_quantization_config
|
53
|
+
3. New operator set names: EXP, SIN, COS
|
54
|
+
Compatability behavior: Not required
|
55
|
+
:return: TargetPlatformCapabilities instance of schema version 2
|
40
56
|
"""
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
57
|
+
v1_default_qco = tpc.default_qco.base_config
|
58
|
+
v2_default_qco = schema_v2.OpQuantizationConfig(
|
59
|
+
default_weight_attr_config=v1_default_qco.default_weight_attr_config,
|
60
|
+
attr_weights_configs_mapping=v1_default_qco.attr_weights_configs_mapping,
|
61
|
+
activation_quantization_method=v1_default_qco.activation_quantization_method,
|
62
|
+
activation_n_bits=v1_default_qco.activation_n_bits,
|
63
|
+
supported_input_activation_n_bits=v1_default_qco.supported_input_activation_n_bits,
|
64
|
+
enable_activation_quantization=False, # set to False by default because feature not exist in schema v1
|
65
|
+
quantization_preserving=v1_default_qco.quantization_preserving,
|
66
|
+
fixed_scale=v1_default_qco.fixed_scale,
|
67
|
+
fixed_zero_point=v1_default_qco.fixed_zero_point,
|
68
|
+
simd_size=v1_default_qco.simd_size,
|
69
|
+
signedness=v1_default_qco.signedness)
|
70
|
+
|
71
|
+
schema_v2_fusing_patters = []
|
72
|
+
for fussing_pattern in tpc.fusing_patterns:
|
73
|
+
schema_v2_fusing_patters.append(
|
74
|
+
schema_v2.Fusing(operator_groups=fussing_pattern.operator_groups,
|
75
|
+
fuse_op_quantization_config=copy.deepcopy(v2_default_qco),
|
76
|
+
name=fussing_pattern.name))
|
77
|
+
|
78
|
+
tpc_schema_v2 = schema_v2.TargetPlatformCapabilities(default_qco=tpc.default_qco,
|
79
|
+
operator_set=tpc.operator_set,
|
80
|
+
fusing_patterns=schema_v2_fusing_patters,
|
81
|
+
tpc_minor_version=tpc.tpc_minor_version,
|
82
|
+
tpc_patch_version=tpc.tpc_patch_version,
|
83
|
+
tpc_platform_type=tpc.tpc_platform_type,
|
84
|
+
add_metadata=tpc.add_metadata,
|
85
|
+
insert_preserving_quantizers=False) # set to False by default because feature not exist in schema v1
|
86
|
+
return tpc_schema_v2
|
87
|
+
|
49
88
|
|
50
89
|
def get_conversion_map() -> dict:
|
51
90
|
"""
|
@@ -60,7 +99,8 @@ def get_conversion_map() -> dict:
|
|
60
99
|
return conversion_map
|
61
100
|
|
62
101
|
|
63
|
-
def tpc_to_current_schema_version(
|
102
|
+
def tpc_to_current_schema_version(
|
103
|
+
tpc: Union[all_tpc_types]) -> current_schema.TargetPlatformCapabilities:
|
64
104
|
"""
|
65
105
|
Given tpc instance of some schema version, convert it to the current MCT schema version.
|
66
106
|
|
@@ -91,6 +91,8 @@ class OperatorSetNames(str, Enum):
|
|
91
91
|
STRIDED_SLICE = "StridedSlice"
|
92
92
|
SSD_POST_PROCESS = "SSDPostProcess"
|
93
93
|
EXP = "Exp"
|
94
|
+
SIN = "Sin"
|
95
|
+
COS = "Cos"
|
94
96
|
|
95
97
|
@classmethod
|
96
98
|
def get_values(cls):
|
@@ -218,11 +220,11 @@ class TargetPlatformCapabilities(BaseModel):
|
|
218
220
|
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
|
219
221
|
"""
|
220
222
|
default_qco: QuantizationConfigOptions
|
221
|
-
operator_set: Optional[Tuple[OperatorsSet, ...]]
|
222
|
-
fusing_patterns: Optional[Tuple[Fusing, ...]]
|
223
|
-
tpc_minor_version: Optional[int]
|
224
|
-
tpc_patch_version: Optional[int]
|
225
|
-
tpc_platform_type: Optional[str]
|
223
|
+
operator_set: Optional[Tuple[OperatorsSet, ...]] = None
|
224
|
+
fusing_patterns: Optional[Tuple[Fusing, ...]] = None
|
225
|
+
tpc_minor_version: Optional[int] = None
|
226
|
+
tpc_patch_version: Optional[int] = None
|
227
|
+
tpc_platform_type: Optional[str] = None
|
226
228
|
add_metadata: bool = True
|
227
229
|
name: Optional[str] = "default_tpc"
|
228
230
|
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py
CHANGED
@@ -100,7 +100,10 @@ class AttachTpcToKeras(AttachTpcToFramework):
|
|
100
100
|
OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
|
101
101
|
OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
|
102
102
|
OperatorSetNames.L2NORM: [tf.math.l2_normalize],
|
103
|
-
OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess]
|
103
|
+
OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess],
|
104
|
+
OperatorSetNames.EXP: [tf.math.exp],
|
105
|
+
OperatorSetNames.SIN: [tf.math.sin],
|
106
|
+
OperatorSetNames.COS: [tf.math.cos]
|
104
107
|
}
|
105
108
|
|
106
109
|
self._opset2attr_mapping = {
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py
CHANGED
@@ -99,6 +99,9 @@ class AttachTpcToPytorch(AttachTpcToFramework):
|
|
99
99
|
Eq('p', 2) | Eq('p', None))],
|
100
100
|
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
|
101
101
|
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
|
102
|
+
OperatorSetNames.EXP: [torch.exp],
|
103
|
+
OperatorSetNames.SIN: [torch.sin],
|
104
|
+
OperatorSetNames.COS: [torch.cos],
|
102
105
|
}
|
103
106
|
|
104
107
|
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
|
@@ -31,6 +31,9 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
|
|
31
31
|
OpQuantizationConfig, QuantizationConfigOptions
|
32
32
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
|
33
33
|
|
34
|
+
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
|
35
|
+
|
36
|
+
|
34
37
|
class FrameworkQuantizationCapabilities(ImmutableClass):
|
35
38
|
"""
|
36
39
|
Attach framework information to a modeled hardware.
|
@@ -94,20 +97,26 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
|
|
94
97
|
"""
|
95
98
|
return self.op_sets_to_layers.get_layers_by_op(op)
|
96
99
|
|
97
|
-
def get_fusing_patterns(self) -> List[List[Any]]:
|
100
|
+
def get_fusing_patterns(self) -> List[Dict[List[Any], OpQuantizationConfig]]:
|
98
101
|
"""
|
99
102
|
|
100
|
-
Returns: List of patterns of layers/LayerFilterParams to fuse.
|
103
|
+
Returns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.
|
101
104
|
|
102
105
|
"""
|
103
|
-
|
106
|
+
|
107
|
+
patterns = []
|
104
108
|
if self.tpc.fusing_patterns is None:
|
105
|
-
return
|
109
|
+
return patterns
|
110
|
+
|
106
111
|
for p in self.tpc.fusing_patterns:
|
112
|
+
res = []
|
107
113
|
ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
|
108
114
|
res.extend(itertools.product(*ops))
|
109
|
-
return [list(x) for x in res]
|
110
115
|
|
116
|
+
fused_op_quant_config = getattr(p, FUSED_OP_QUANT_CONFIG, None)
|
117
|
+
patterns.extend({FUSED_LAYER_PATTERN: list(x), FUSED_OP_QUANT_CONFIG: fused_op_quant_config} for x in res)
|
118
|
+
|
119
|
+
return patterns
|
111
120
|
|
112
121
|
def get_info(self) -> Dict[str, Any]:
|
113
122
|
"""
|
@@ -230,7 +239,17 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
|
|
230
239
|
def is_simd_padding(self) -> bool:
|
231
240
|
"""
|
232
241
|
|
233
|
-
Returns: Check if the
|
242
|
+
Returns: Check if the TPC defines that padding due to SIMD constrains occurs.
|
234
243
|
|
235
244
|
"""
|
236
245
|
return self.tpc.is_simd_padding
|
246
|
+
|
247
|
+
@property
|
248
|
+
def insert_preserving_quantizers(self) -> bool:
|
249
|
+
"""
|
250
|
+
|
251
|
+
Returns: Check if the TPC defines that a quantizer for quantization preserving operators should be added to the
|
252
|
+
constructed model.
|
253
|
+
|
254
|
+
"""
|
255
|
+
return self.tpc.insert_preserving_quantizers
|
@@ -12,12 +12,23 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
import json
|
15
16
|
from pathlib import Path
|
16
17
|
from typing import Union
|
17
18
|
|
18
19
|
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
|
19
20
|
from model_compression_toolkit.target_platform_capabilities.schema.schema_compatability import is_tpc_instance, \
|
20
|
-
tpc_to_current_schema_version, tpc_or_str_type
|
21
|
+
tpc_to_current_schema_version, tpc_or_str_type, get_schema_by_version
|
22
|
+
|
23
|
+
|
24
|
+
def _get_json_schema(tpc_json_path: str):
|
25
|
+
"""
|
26
|
+
Given a TPC json file path, extract the schema version from it, and return schema object matched to that
|
27
|
+
schema version.
|
28
|
+
"""
|
29
|
+
with open(tpc_json_path, 'r', encoding='utf-8') as f:
|
30
|
+
schema_version = str(json.load(f)["SCHEMA_VERSION"])
|
31
|
+
return get_schema_by_version(schema_version)
|
21
32
|
|
22
33
|
|
23
34
|
def _get_tpc_from_json(tpc_path: str) -> schema.TargetPlatformCapabilities:
|
@@ -40,7 +51,11 @@ def _get_tpc_from_json(tpc_path: str) -> schema.TargetPlatformCapabilities:
|
|
40
51
|
raise ValueError(f"Error reading the file '{tpc_path}': {e.strerror}.") from e
|
41
52
|
|
42
53
|
try:
|
43
|
-
|
54
|
+
# json_schema = _get_json_schema(tpc_path)
|
55
|
+
# tpc = json_schema.TargetPlatformCapabilities.parse_raw(data)
|
56
|
+
# return tpc_to_current_schema_version(tpc)
|
57
|
+
tpc = schema.TargetPlatformCapabilities.parse_raw(data)
|
58
|
+
return tpc_to_current_schema_version(tpc)
|
44
59
|
except ValueError as e:
|
45
60
|
raise ValueError(f"Invalid JSON for loading TargetPlatformCapabilities in '{tpc_path}': {e}.") from e
|
46
61
|
except Exception as e:
|
File without changes
|
{mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/top_level.txt
RENAMED
File without changes
|