mct-nightly 2.3.0.20250422.534__py3-none-any.whl → 2.3.0.20250424.534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/METADATA +8 -8
- {mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/RECORD +12 -12
- {mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_graph.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +4 -11
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -3
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +11 -1
- model_compression_toolkit/target_platform_capabilities/schema/v2.py +106 -4
- {mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.20250422.534
|
3
|
+
Version: 2.3.0.20250424.534
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: Apache Software License
|
@@ -34,7 +34,7 @@ Dynamic: summary
|
|
34
34
|
<div align="center" markdown="1">
|
35
35
|
<p>
|
36
36
|
<a href="https://sony.github.io/model_optimization/" target="_blank">
|
37
|
-
<img src="https://
|
37
|
+
<img src="https://raw.githubusercontent.com/sony/model_optimization/refs/heads/main/docsrc/images/mctHeader1-cropped.svg" width="1000"></a>
|
38
38
|
</p>
|
39
39
|
|
40
40
|
______________________________________________________________________
|
@@ -100,7 +100,7 @@ For further details, please see [Supported features and algorithms](#high-level-
|
|
100
100
|
<div align="center">
|
101
101
|
<p align="center">
|
102
102
|
|
103
|
-
<img src="https://
|
103
|
+
<img src="https://raw.githubusercontent.com/sony/model_optimization/refs/heads/main/docsrc/images/mctDiagram_clean.svg" width="800">
|
104
104
|
</p>
|
105
105
|
</div>
|
106
106
|
|
@@ -181,16 +181,16 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
|
|
181
181
|
## <div align="center">Results</div>
|
182
182
|
|
183
183
|
<p align="center">
|
184
|
-
<img src="https://
|
185
|
-
<img src="https://
|
186
|
-
<img src="https://
|
187
|
-
<img src="https://
|
184
|
+
<img src="https://raw.githubusercontent.com/sony/model_optimization/refs/heads/main/docsrc/images/Classification.png" width="200">
|
185
|
+
<img src="https://raw.githubusercontent.com/sony/model_optimization/refs/heads/main/docsrc/images/SemSeg.png" width="200">
|
186
|
+
<img src="https://raw.githubusercontent.com/sony/model_optimization/refs/heads/main/docsrc/images/PoseEst.png" width="200">
|
187
|
+
<img src="https://raw.githubusercontent.com/sony/model_optimization/refs/heads/main/docsrc/images/ObjDet.png" width="200">
|
188
188
|
|
189
189
|
MCT can quantize an existing 32-bit floating-point model to an 8-bit fixed-point (or less) model without compromising accuracy.
|
190
190
|
Below is a graph of [MobileNetV2](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html) accuracy on ImageNet vs average bit-width of weights (X-axis), using **single-precision** quantization, **mixed-precision** quantization, and mixed-precision quantization with GPTQ.
|
191
191
|
|
192
192
|
<p align="center">
|
193
|
-
<img src="https://
|
193
|
+
<img src="https://raw.githubusercontent.com/sony/model_optimization/refs/heads/main/docsrc/images/torch_mobilenetv2.png" width="800">
|
194
194
|
|
195
195
|
For more results, please see [1]
|
196
196
|
|
{mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250424.534.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=HmFpViJmJPVcQg5km-gnodqRgdt3lc5eqANLwoWrMqM,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
@@ -34,7 +34,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
|
|
34
34
|
model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=W8qZejLwbm-lkvNF3GepNL3ypO10vFRxOxbq-o_rt_I,15479
|
35
35
|
model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
|
36
36
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
37
|
-
model_compression_toolkit/core/common/graph/base_graph.py,sha256=
|
37
|
+
model_compression_toolkit/core/common/graph/base_graph.py,sha256=BSQpKy0BXoGX0G0bySTo72n2isTqvtpkbRYYa8-hPO4,41435
|
38
38
|
model_compression_toolkit/core/common/graph/base_node.py,sha256=AbUadAT581zelVcGcK9_--6CAGiht9qwkeWahwT3RzE,33389
|
39
39
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
40
40
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
|
@@ -75,10 +75,10 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
|
|
75
75
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
|
76
76
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
77
77
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
|
78
|
-
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=
|
78
|
+
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=D2sNbTPMDsDyUE18NUpVJN27AgdwwhpdOJ8UMLmhdPA,40420
|
79
79
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
|
80
80
|
model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
81
|
-
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=
|
81
|
+
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
|
82
82
|
model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
|
83
83
|
model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
|
84
84
|
model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
|
@@ -101,7 +101,7 @@ model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB
|
|
101
101
|
model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=77DB1vqq_gHwbUjeCHRaq1Q-V4wEtdVdwkGezcZgToA,5021
|
102
102
|
model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=_LcDAxLeC5I0KdMHS8jib5XxIKO2ZLavXYuSMIPIQBo,5868
|
103
103
|
model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
104
|
-
model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=
|
104
|
+
model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=034kgwe0ydyLXsV83KqxKyyHkoUQH06ai0leLyg0p8I,13019
|
105
105
|
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=lyWPvnoX8BmulhLKR20r5gT2_Yan7P40d8EcgDhErPk,4905
|
106
106
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
107
107
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
|
@@ -112,7 +112,7 @@ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,
|
|
112
112
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=7eG7dl1TcbdnHwgmvyjarxLs0o6Lw_9VAjXAm4rsiBk,3791
|
113
113
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
114
114
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
|
115
|
-
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=
|
115
|
+
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=_hhRb5eeFwbtPddu2xdLi7qK1RsxoR7UHUfjO0ICM3Q,30586
|
116
116
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
|
117
117
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=_m-XkEMJMHf0gYwVIXAoHVjdRa2NXt_gYdwBlw76ZR8,24031
|
118
118
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
|
@@ -439,7 +439,7 @@ model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema
|
|
439
439
|
model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=TtMPbiibV6Hk53nl5Y_ctfpI6mSbd8VVH9fxnv5j9eM,4430
|
440
440
|
model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
|
441
441
|
model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=oWKNQnnz04kmijmdWtRyXgVXbJ6BG_V_bUBz_MfUM94,27116
|
442
|
-
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=
|
442
|
+
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=ncKPHVNyq_Yy_F7HOVDHT68EDfRnaB9yVnEP3C89GJk,10627
|
443
443
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
|
444
444
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
|
445
445
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=Ehwpd_sL6zxmJFpJugOdN9uNxNX05nijvOCilNfHnFs,7162
|
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
528
528
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
529
529
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
530
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
534
|
-
mct_nightly-2.3.0.
|
531
|
+
mct_nightly-2.3.0.20250424.534.dist-info/METADATA,sha256=wMqM0-nGTBa189h4xpdr-iY2-QUxlm1vVnXkB7ogmzU,25560
|
532
|
+
mct_nightly-2.3.0.20250424.534.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
533
|
+
mct_nightly-2.3.0.20250424.534.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
534
|
+
mct_nightly-2.3.0.20250424.534.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250424.000534"
|
@@ -754,7 +754,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
754
754
|
"""
|
755
755
|
while node.is_quantization_preserving():
|
756
756
|
prev_nodes = self.get_prev_nodes(node)
|
757
|
-
assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
|
757
|
+
assert len(prev_nodes) == 1, f"Activation preserving node should have only 1 input, but node {node.name} has {len(prev_nodes)} inputs."
|
758
758
|
node = prev_nodes[0]
|
759
759
|
return node
|
760
760
|
|
@@ -51,7 +51,6 @@ class BitwidthMode(Enum):
|
|
51
51
|
single-precision nodes. To compute custom single precision configuration, use QCustom.
|
52
52
|
"""
|
53
53
|
Float = auto()
|
54
|
-
Q8Bit = auto()
|
55
54
|
QMaxBit = auto()
|
56
55
|
QMinBit = auto()
|
57
56
|
QCustom = auto()
|
@@ -573,7 +572,7 @@ class ResourceUtilizationCalculator:
|
|
573
572
|
not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
|
574
573
|
return 0
|
575
574
|
|
576
|
-
act_qc =
|
575
|
+
act_qc = self._extract_qc(a_node, act_qcs)
|
577
576
|
a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
|
578
577
|
w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
|
579
578
|
node_bops = a_nbits * w_nbits * node_mac
|
@@ -708,23 +707,20 @@ class ResourceUtilizationCalculator:
|
|
708
707
|
Returns:
|
709
708
|
Activation bit-width.
|
710
709
|
"""
|
710
|
+
n = self.graph.retrieve_preserved_quantization_node(n)
|
711
711
|
if act_qc:
|
712
712
|
assert bitwidth_mode == BitwidthMode.QCustom
|
713
713
|
return act_qc.activation_n_bits if act_qc.quant_mode == ActivationQuantizationMode.QUANT else FLOAT_BITWIDTH
|
714
714
|
|
715
|
-
if bitwidth_mode == BitwidthMode.Float or not
|
716
|
-
n.is_quantization_preserving()):
|
715
|
+
if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
|
717
716
|
return FLOAT_BITWIDTH
|
718
717
|
|
719
|
-
if bitwidth_mode == BitwidthMode.Q8Bit:
|
720
|
-
return 8
|
721
|
-
|
722
718
|
if bitwidth_mode in self._bitwidth_mode_fn:
|
723
719
|
candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg]
|
724
720
|
return self._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)
|
725
721
|
|
726
722
|
if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
|
727
|
-
qcs =
|
723
|
+
qcs = n.get_unique_activation_candidates()
|
728
724
|
if len(qcs) != 1:
|
729
725
|
raise ValueError(f'Could not retrieve the activation quantization candidate for node {n} '
|
730
726
|
f'as it has {len(qcs)}!=1 unique candidates.')
|
@@ -760,9 +756,6 @@ class ResourceUtilizationCalculator:
|
|
760
756
|
if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr):
|
761
757
|
return FLOAT_BITWIDTH
|
762
758
|
|
763
|
-
if bitwidth_mode == BitwidthMode.Q8Bit:
|
764
|
-
return 8
|
765
|
-
|
766
759
|
node_qcs = n.get_unique_weights_candidates(w_attr)
|
767
760
|
w_qcs = [qc.weights_quantization_cfg.get_attr_config(w_attr) for qc in node_qcs]
|
768
761
|
if bitwidth_mode in cls._bitwidth_mode_fn:
|
@@ -16,7 +16,7 @@ from collections import defaultdict
|
|
16
16
|
|
17
17
|
import numpy as np
|
18
18
|
from pulp import *
|
19
|
-
from typing import Dict, Tuple, Any
|
19
|
+
from typing import Dict, Tuple, Any, List
|
20
20
|
|
21
21
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
|
22
22
|
|
@@ -20,6 +20,8 @@ from model_compression_toolkit.core.common.matchers.node_matcher import BaseNode
|
|
20
20
|
from model_compression_toolkit.logger import Logger
|
21
21
|
|
22
22
|
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
|
23
|
+
from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
|
24
|
+
|
23
25
|
|
24
26
|
@dataclass
|
25
27
|
class ManualBitWidthSelection:
|
@@ -221,9 +223,10 @@ class BitWidthConfig:
|
|
221
223
|
if isinstance(attr_str, str) and isinstance(manual_bit_width_selection.attr, str):
|
222
224
|
if attr_str.find(manual_bit_width_selection.attr) != -1:
|
223
225
|
attr.append(attr_str)
|
224
|
-
|
225
|
-
|
226
|
-
|
226
|
+
# this is a positional attribute, so it needs to be handled separately.
|
227
|
+
# Search manual_bit_width_selection's attribute that contain the POS_ATTR string.
|
228
|
+
elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr:
|
229
|
+
attr.append(POS_ATTR)
|
227
230
|
if len(attr) == 0:
|
228
231
|
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')
|
229
232
|
|
@@ -67,7 +67,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
|
|
67
67
|
nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
|
68
68
|
nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
|
69
69
|
|
70
|
-
for n in graph.
|
70
|
+
for n in graph.get_topo_sorted_nodes():
|
71
71
|
manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
|
72
72
|
WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
|
73
73
|
set_quantization_configs_to_node(node=n,
|
@@ -199,6 +199,16 @@ def set_quantization_configs_to_node(node: BaseNode,
|
|
199
199
|
if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
|
200
200
|
not node.get_has_activation():
|
201
201
|
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
|
202
|
+
elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
|
203
|
+
prev_nodes = graph.get_prev_nodes(node)
|
204
|
+
if len(prev_nodes) != 1:
|
205
|
+
# Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
|
206
|
+
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
|
207
|
+
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
|
208
|
+
elif not prev_nodes[0].is_quantization_preserving() or not prev_nodes[0].is_activation_quantization_enabled():
|
209
|
+
# Preserving the quantization of an unquantized node isn't possible, so disable it.
|
210
|
+
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
|
211
|
+
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
|
202
212
|
|
203
213
|
|
204
214
|
def create_node_activation_qc(qc: QuantizationConfig,
|
@@ -14,9 +14,9 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
import pprint
|
16
16
|
from enum import Enum
|
17
|
-
from typing import Dict, Any, Tuple, Optional
|
17
|
+
from typing import Dict, Any, Union, Tuple, Optional, Annotated
|
18
18
|
|
19
|
-
from pydantic import BaseModel, root_validator, model_validator, ConfigDict
|
19
|
+
from pydantic import BaseModel, Field, root_validator, model_validator, ConfigDict
|
20
20
|
|
21
21
|
from mct_quantizers import QuantizationMethod
|
22
22
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
@@ -29,8 +29,7 @@ from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
|
|
29
29
|
TargetPlatformModelComponent,
|
30
30
|
OperatorsSetBase,
|
31
31
|
OperatorsSet,
|
32
|
-
OperatorSetGroup
|
33
|
-
Fusing)
|
32
|
+
OperatorSetGroup)
|
34
33
|
|
35
34
|
|
36
35
|
class OperatorSetNames(str, Enum):
|
@@ -98,6 +97,109 @@ class OperatorSetNames(str, Enum):
|
|
98
97
|
return [v.value for v in cls]
|
99
98
|
|
100
99
|
|
100
|
+
class Fusing(TargetPlatformModelComponent):
|
101
|
+
"""
|
102
|
+
Fusing defines a tuple of operators that should be combined and treated as a single operator,
|
103
|
+
hence no quantization is applied between them.
|
104
|
+
|
105
|
+
Attributes:
|
106
|
+
operator_groups (Tuple[Union[OperatorsSet, OperatorSetGroup], ...]): A tuple of operator groups,
|
107
|
+
each being either an OperatorSetGroup or an OperatorsSet.
|
108
|
+
fuse_op_quantization_config (Optional[OpQuantizationConfig]): The quantization configuration for the fused operator.
|
109
|
+
name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
|
110
|
+
"""
|
111
|
+
operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
|
112
|
+
fuse_op_quantization_config: Optional[OpQuantizationConfig] = None
|
113
|
+
name: Optional[str] = None # Will be set in the validator if not given.
|
114
|
+
|
115
|
+
model_config = ConfigDict(frozen=True)
|
116
|
+
|
117
|
+
@model_validator(mode="before")
|
118
|
+
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
119
|
+
"""
|
120
|
+
Validate the operator_groups and set the name by concatenating operator group names.
|
121
|
+
|
122
|
+
Args:
|
123
|
+
values (Dict[str, Any]): Input data.
|
124
|
+
|
125
|
+
Returns:
|
126
|
+
Dict[str, Any]: Modified input data with 'name' set.
|
127
|
+
"""
|
128
|
+
operator_groups = values.get('operator_groups')
|
129
|
+
|
130
|
+
# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
|
131
|
+
if isinstance(operator_groups, list):
|
132
|
+
values['operator_groups'] = tuple(operator_groups)
|
133
|
+
|
134
|
+
if values.get('name') is None:
|
135
|
+
# Generate the concatenated name from the operator groups
|
136
|
+
concatenated_name = "_".join([
|
137
|
+
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
|
138
|
+
for op in values['operator_groups']
|
139
|
+
])
|
140
|
+
values['name'] = concatenated_name
|
141
|
+
|
142
|
+
return values
|
143
|
+
|
144
|
+
@model_validator(mode="after")
|
145
|
+
def validate_after_initialization(cls, model: 'Fusing') -> Any:
|
146
|
+
"""
|
147
|
+
Perform validation after the model has been instantiated.
|
148
|
+
Ensures that there are at least two operator groups.
|
149
|
+
"""
|
150
|
+
if len(model.operator_groups) < 2:
|
151
|
+
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
|
152
|
+
|
153
|
+
return model
|
154
|
+
|
155
|
+
def contains(self, other: Any) -> bool:
|
156
|
+
"""
|
157
|
+
Determines if the current Fusing instance contains another Fusing instance.
|
158
|
+
|
159
|
+
Args:
|
160
|
+
other (Any): The other Fusing instance to check against.
|
161
|
+
|
162
|
+
Returns:
|
163
|
+
bool: True if the other Fusing instance is contained within this one, False otherwise.
|
164
|
+
"""
|
165
|
+
if not isinstance(other, Fusing):
|
166
|
+
return False
|
167
|
+
|
168
|
+
# Check for containment by comparing operator groups
|
169
|
+
for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
|
170
|
+
for j in range(len(other.operator_groups)):
|
171
|
+
if self.operator_groups[i + j] != other.operator_groups[j] and not (
|
172
|
+
isinstance(self.operator_groups[i + j], OperatorSetGroup) and (
|
173
|
+
other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
|
174
|
+
break
|
175
|
+
else:
|
176
|
+
# If all checks pass, the other Fusing instance is contained
|
177
|
+
return True
|
178
|
+
# Other Fusing instance is not contained
|
179
|
+
return False
|
180
|
+
|
181
|
+
def get_info(self) -> Union[Dict[str, str], str]:
|
182
|
+
"""
|
183
|
+
Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
|
184
|
+
|
185
|
+
Returns:
|
186
|
+
Union[Dict[str, str], str]: A dictionary with the Fusing instance's name as the key
|
187
|
+
and the sequence of operator groups as the value,
|
188
|
+
or just the sequence of operator groups if no name is set.
|
189
|
+
"""
|
190
|
+
if self.name is not None:
|
191
|
+
return {
|
192
|
+
self.name: ' -> '.join([
|
193
|
+
x.name.value if isinstance(x.name, OperatorSetNames) else x.name
|
194
|
+
for x in self.operator_groups
|
195
|
+
])
|
196
|
+
}
|
197
|
+
return ' -> '.join([
|
198
|
+
x.name.value if isinstance(x.name, OperatorSetNames) else x.name
|
199
|
+
for x in self.operator_groups
|
200
|
+
])
|
201
|
+
|
202
|
+
|
101
203
|
class TargetPlatformCapabilities(BaseModel):
|
102
204
|
"""
|
103
205
|
Represents the hardware configuration used for quantized model inference.
|
File without changes
|
{mct_nightly-2.3.0.20250422.534.dist-info → mct_nightly-2.3.0.20250424.534.dist-info}/top_level.txt
RENAMED
File without changes
|