mct-nightly 2.1.0.20240811.503__py3-none-any.whl → 2.1.0.20240812.432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/RECORD +16 -15
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/__init__.py +1 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -0
- model_compression_toolkit/core/common/quantization/bit_width_config.py +91 -0
- model_compression_toolkit/core/common/quantization/core_config.py +8 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +88 -22
- model_compression_toolkit/core/graph_prep_runner.py +16 -9
- model_compression_toolkit/core/runner.py +1 -0
- model_compression_toolkit/pruning/keras/pruning_facade.py +1 -0
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -0
- {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/RECORD
RENAMED
@@ -1,13 +1,13 @@
-model_compression_toolkit/__init__.py,sha256=
+model_compression_toolkit/__init__.py,sha256=dFfNYHNevEMx3n6CPcaXHcIQxE4Nlhrsckn2CtIDLiY,1573
 model_compression_toolkit/constants.py,sha256=0qrEGjX36Oo7Lt8mR0LD2aSe2xA7gKrhkzBGp7g5eiA,4345
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
 model_compression_toolkit/metadata.py,sha256=UtXS5ClK-qPoxGRuytlDGZSzgLo911dMni2EFRcg6io,3623
-model_compression_toolkit/core/__init__.py,sha256=
+model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-tXYeGxTo5YnNUM,2077
 model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
-model_compression_toolkit/core/graph_prep_runner.py,sha256=
+model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
 model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
-model_compression_toolkit/core/runner.py,sha256=
+model_compression_toolkit/core/runner.py,sha256=XQDNJirZkVJ_FXP72d7tbVc_Tr3Jw0Eqm_kxNHW8kPs,13636
 model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
 model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
 model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
@@ -70,7 +70,7 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
 model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=KifDMbm7qkSfvSl6pcZzQ82naIXzeKL6aT-VsvWZYyc,7901
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=HILF7CIn-GYPvPmTFyvjWLhuLDwSGwdBcAaKFgVYrwk,4745
-model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=
+model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=3ZOI-RNp5faT-U2Og7rLW9EKwBB6ooa7-RwSsWJmquo,14022
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py,sha256=ttc8wPa_9LZansutQ2f1ss-RTzgTv739wy3qsdLzyyk,4217
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py,sha256=QhuqaECEGLnYC08iD6-2XXcU7NXbPzYf1sQcjYlGak8,1682
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py,sha256=WC1EHoNuo_lrzy4NRhGJ1cgmJ2IsFsbmP86mrVO3AVA,21506
@@ -98,8 +98,9 @@ model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB
 model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=APY8BsM9B7ZxVCH6n1xs9fSCTB_A9ou9gHrCQl1DOdI,5131
 model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=4ohJrJHNzZk5uMnZEYkwLx2TDGzkh5kRhLGNVYNC6dc,5978
 model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
+model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=r6VQXgyJxX_AM1JTzv-sTcrvCTnktBfOkVP20RllNmk,4586
 model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=yU-Cr6S4wOSkDk57iH2NVe-WII0whOhLryejkomCOt4,4940
-model_compression_toolkit/core/common/quantization/core_config.py,sha256=
+model_compression_toolkit/core/common/quantization/core_config.py,sha256=f0uSuY9mX-vLX_1s2DemPARQlAXmLPKJKPtCArz3pZI,2670
 model_compression_toolkit/core/common/quantization/debug_config.py,sha256=8G8SpE_4rb8xBp8d6mMq8R_OnXJ_1oxB2g-Lxk9EJCM,1691
 model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
 model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=YycYN8_JMzvSR3pTVm5dT5x4zP3yBHn0Z9agnwrvOKI,26395
@@ -108,7 +109,7 @@ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,
 model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
 model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
 model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
-model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=
+model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=0pZVO4wsNP815R9ZOd5ojC_OdNEeKkxYKdjggsqsZKg,17750
 model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
 model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=Fd_gxr5js-mqEwucaRR1CQAZ1W_wna19L1gAPeOzxRQ,23610
 model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
@@ -380,9 +381,9 @@ model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256
 model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=6uxq_w62jn8DDOt9T7VtA6jZ8jTAPcbTufKFOYpVUm4,8768
 model_compression_toolkit/pruning/__init__.py,sha256=lQMZS8G0pvR1LVi53nnJHNXgLNTan_MWMdwsVxhjrow,1106
 model_compression_toolkit/pruning/keras/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
-model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=
+model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=vDpY97xirGF-o5XB6HvG_y2bL4LzfiTW3cPURTvaeKI,8707
 model_compression_toolkit/pruning/pytorch/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
-model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=
+model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=1uo5jWgbFNNhRbfb8da5REymMUdLJ3JidR8aAMXCBoE,9493
 model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
 model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
 model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -507,7 +508,7 @@ model_compression_toolkit/xquant/common/core_report_generator.py,sha256=GHnJJpK6
 model_compression_toolkit/xquant/common/dataset_utils.py,sha256=91uXF9UwxdY7BvUT0FNkFm8a69c8oK8Xdl-y7lbuJxk,1649
 model_compression_toolkit/xquant/common/framework_report_utils.py,sha256=YE49232ESflW6ZaUABF1pk_GGHBxa_F1X5oRN2Jogys,3734
 model_compression_toolkit/xquant/common/model_analyzer.py,sha256=T_8OetIQNqR0nkfSatWsEceXSPYpHfYjboBPIyR03-w,3953
-model_compression_toolkit/xquant/common/model_folding_utils.py,sha256=
+model_compression_toolkit/xquant/common/model_folding_utils.py,sha256=7XMNmsngJgCPVjsuMNt6g4hzhkviB45qUmNRe9jQE7g,4815
 model_compression_toolkit/xquant/common/similarity_calculator.py,sha256=yCs_vlOThLzq7z-u2PkcEErLj7N7qCBPpRa6_5h34J8,10460
 model_compression_toolkit/xquant/common/similarity_functions.py,sha256=Atah1otdX9oUUch2JK-p-e291QHtkP_c4DfLG9WWo1Y,2935
 model_compression_toolkit/xquant/common/tensorboard_utils.py,sha256=85ABGQGKPZzctyZCHLazK0GxZ2ZUtQA3hZ_9fPiuMs0,6533
@@ -526,8 +527,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
 model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
 model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
 model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=yjghWXxqOtT-QXoXBOuJyh45yUpFI0pKjdDegum2i68,9705
-mct_nightly-2.1.0.
-mct_nightly-2.1.0.
-mct_nightly-2.1.0.
-mct_nightly-2.1.0.
-mct_nightly-2.1.0.
+mct_nightly-2.1.0.20240812.432.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-2.1.0.20240812.432.dist-info/METADATA,sha256=2qqUeeA_e60PJ8S7cITQK2UlEOxMvWKKncdTK-H3v9E,19718
+mct_nightly-2.1.0.20240812.432.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+mct_nightly-2.1.0.20240812.432.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.1.0.20240812.432.dist-info/RECORD,,
model_compression_toolkit/__init__.py
CHANGED
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.1.0.
+__version__ = "2.1.0.20240812.000432"
model_compression_toolkit/core/__init__.py
CHANGED
@@ -19,6 +19,7 @@ from model_compression_toolkit.core.common.quantization.debug_config import Debu
 from model_compression_toolkit.core.common.quantization import quantization_config
 from model_compression_toolkit.core.common.mixed_precision import mixed_precision_quantization_config
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG
+from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
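With this export, BitWidthConfig becomes reachable from the public core namespace next to CoreConfig and QuantizationConfig. A minimal sketch (assuming an installed mct-nightly build at or after this version):

import model_compression_toolkit as mct

# BitWidthConfig is now re-exported from model_compression_toolkit.core.
bwc = mct.core.BitWidthConfig()
print(bwc)  # __repr__ prints the attribute dict; the selection list starts empty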
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
CHANGED
@@ -67,6 +67,7 @@ def compute_resource_utilization_data(in_model: Any,
                                      fw_info,
                                      fw_impl,
                                      tpc,
+                                     bit_width_config=core_config.bit_width_config,
                                      mixed_precision_enable=mixed_precision_enable)
 
     # Compute parameters sum
@@ -227,6 +228,7 @@ def requires_mixed_precision(in_model: Any,
                             fw_info,
                             fw_impl,
                             tpc,
+                            bit_width_config=core_config.bit_width_config,
                             mixed_precision_enable=False)
     # Compute max weights memory in bytes
     weights_memory_by_layer_bytes, _ = compute_nodes_weights_params(transformed_graph, fw_info)
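Both resource-utilization entry points now forward the manual bit-width settings from core_config into graph preparation. A hedged sketch of how this surfaces through the public Keras facade (assuming keras_resource_utilization_data keeps its 2.1 signature; the model and data generator are placeholders):

import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

def representative_data_gen():
    # Placeholder calibration batch; replace with real samples.
    yield [np.random.randn(1, 16, 16, 3).astype(np.float32)]

model = tf.keras.Sequential([tf.keras.layers.Conv2D(4, 3, input_shape=(16, 16, 3))])

# CoreConfig now carries a BitWidthConfig, which compute_resource_utilization_data
# passes down as bit_width_config (see the '+' lines above).
core_config = mct.core.CoreConfig(bit_width_config=mct.core.BitWidthConfig())
ru_data = mct.core.keras_resource_utilization_data(model,
                                                   representative_data_gen,
                                                   core_config=core_config)
print(ru_data)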
model_compression_toolkit/core/common/quantization/bit_width_config.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List, Union, Dict
+
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
+from model_compression_toolkit.logger import Logger
+
+
+class ManualBitWidthSelection:
+    """
+    Class to encapsulate the manual bit-width selection configuration for a specific filter.
+
+    Attributes:
+        filter (BaseNodeMatcher): The filter used to select nodes for bit-width manipulation.
+        bit_width (int): The bit width to be applied to the selected nodes.
+    """
+    def __init__(self,
+                 filter: BaseNodeMatcher,
+                 bit_width: int):
+        self.filter = filter
+        self.bit_width = bit_width
+
+
+class BitWidthConfig:
+    """
+    Class to manage manual bit-width configurations.
+
+    Attributes:
+        manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations.
+    """
+    def __init__(self,
+                 manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = None):
+        self.manual_activation_bit_width_selection_list = [] if manual_activation_bit_width_selection_list is None else manual_activation_bit_width_selection_list
+
+    def __repr__(self):
+        # Used for debugging, thus no cover.
+        return str(self.__dict__)  # pragma: no cover
+
+    def set_manual_activation_bit_width(self,
+                                        filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
+                                        bit_widths: Union[List[int], int]):
+        """
+        Add a manual bit-width selection to the configuration.
+
+        Args:
+            filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
+            bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
+                If a single value is given, it will be applied to all the filters.
+        """
+        filters = [filters] if not isinstance(filters, list) else filters
+        bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
+        if len(bit_widths) > 1 and len(bit_widths) != len(filters):
+            Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
+                            f"must match the number of filters {len(filters)}, or a single bit_width value "
+                            f"should be provided for all filters.")
+        elif len(bit_widths) == 1 and len(filters) > 1:
+            bit_widths = [bit_widths[0] for f in filters]
+        for bit_width, filter in zip(bit_widths, filters):
+            self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]
+
+    def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
+        """
+        Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
+
+        Args:
+            graph (Graph): The graph containing the nodes to be filtered and manipulated.
+
+        Returns:
+            Dict: A dictionary mapping nodes to their new bit-widths.
+        """
+        nodes_to_change_bit_width = {}
+        for manual_bit_width_selection in self.manual_activation_bit_width_selection_list:
+            filtered_nodes = graph.filter(manual_bit_width_selection.filter)
+            if len(filtered_nodes) == 0:
+                Logger.critical(f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
+                                f"to change their bit width to {manual_bit_width_selection.bit_width}.")
+            nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width for n in filtered_nodes})
+        return nodes_to_change_bit_width
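A hedged usage sketch of the new class. It assumes MCT's existing node matchers (NodeTypeFilter from mct.core.network_editor) as the BaseNodeMatcher implementations; any matcher would work the same way:

import tensorflow as tf
import model_compression_toolkit as mct

# NodeTypeFilter is assumed here as the BaseNodeMatcher implementation shipped
# with MCT's network-editor utilities; swap in your own matcher if needed.
NodeTypeFilter = mct.core.network_editor.NodeTypeFilter

bwc = mct.core.BitWidthConfig()
# One filter, one bit-width: pin all Conv2D activations to 4 bits.
bwc.set_manual_activation_bit_width(NodeTypeFilter(tf.keras.layers.Conv2D), 4)
# Several filters sharing one bit-width: the single value is broadcast to all filters.
bwc.set_manual_activation_bit_width(
    [NodeTypeFilter(tf.keras.layers.Dense), NodeTypeFilter(tf.keras.layers.ReLU)], 8)
print(bwc)  # three ManualBitWidthSelection entries recorded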
model_compression_toolkit/core/common/quantization/core_config.py
CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
@@ -22,9 +23,10 @@ class CoreConfig:
     A class to hold the configurations classes of the MCT-core.
     """
     def __init__(self,
-                 quantization_config: QuantizationConfig =
+                 quantization_config: QuantizationConfig = None,
                  mixed_precision_config: MixedPrecisionQuantizationConfig = None,
-
+                 bit_width_config: BitWidthConfig = None,
+                 debug_config: DebugConfig = None
                  ):
         """
 
@@ -32,10 +34,12 @@ class CoreConfig:
            quantization_config (QuantizationConfig): Config for quantization.
            mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization.
                If None, a default MixedPrecisionQuantizationConfig is used.
+           bit_width_config (BitWidthConfig): Config for manual bit-width selection.
            debug_config (DebugConfig): Config for debugging and editing the network quantization process.
         """
-        self.quantization_config = quantization_config
-        self.
+        self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
+        self.bit_width_config = BitWidthConfig() if bit_width_config is None else bit_width_config
+        self.debug_config = DebugConfig() if debug_config is None else debug_config
 
         if mixed_precision_config is None:
             self.mixed_precision_config = MixedPrecisionQuantizationConfig()
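The constructor now defaults every sub-config to None and builds a fresh instance inside __init__, instead of evaluating a default object in the signature (which Python creates once at definition time and would share across all CoreConfig instances). A short sketch of the resulting behavior:

from model_compression_toolkit.core import BitWidthConfig, CoreConfig

# Omitted sub-configs are created per instance inside __init__ ...
cc_default = CoreConfig()
cc_manual = CoreConfig(bit_width_config=BitWidthConfig())

# ... so separate CoreConfig objects never share a mutable default config.
assert cc_default.bit_width_config is not CoreConfig().bit_width_config
assert cc_manual.quantization_config is not None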
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
CHANGED
@@ -15,9 +15,11 @@
 
 
 import copy
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
+from mct_quantizers.common.constants import ACTIVATION_N_BITS
 from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -37,19 +39,21 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.op_q
 
 def set_quantization_configuration_to_graph(graph: Graph,
                                             quant_config: QuantizationConfig,
+                                            bit_width_config: BitWidthConfig = None,
                                             mixed_precision_enable: bool = False,
                                             running_gptq: bool = False) -> Graph:
     """
     Add quantization configuration for each graph node.
 
     Args:
-        graph: Graph for which to add quantization info to each node.
-        quant_config: Quantization configuration containing parameters for how the graph should be quantized.
-
-
+        graph (Graph): Graph for which to add quantization info to each node.
+        quant_config (QuantizationConfig): Quantization configuration containing parameters for how the graph should be quantized.
+        bit_width_config (BitWidthConfig): Configuration for manual bit-width selection. Defaults to None.
+        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
+        running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process. Defaults to False.
 
     Returns:
-        The graph with quantization configurations attached to each node in it.
+        Graph: The graph with quantization configurations attached to each node in it.
     """
 
     if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
@@ -62,13 +66,16 @@ def set_quantization_configuration_to_graph(graph: Graph,
         Logger.warning("Using the HMSE error method for weights quantization parameters search. "
                        "Note: This method may significantly increase runtime during the parameter search process.")
 
+    nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)
+
    for n in graph.nodes:
        set_quantization_configs_to_node(node=n,
                                         graph=graph,
                                         quant_config=quant_config,
                                         fw_info=graph.fw_info,
                                         tpc=graph.tpc,
-                                        mixed_precision_enable=mixed_precision_enable)
+                                        mixed_precision_enable=mixed_precision_enable,
+                                        manual_bit_width_override=nodes_to_manipulate_bit_widths.get(n))
    return graph
 
 
@@ -77,21 +84,32 @@ def set_quantization_configs_to_node(node: BaseNode,
                                     quant_config: QuantizationConfig,
                                     fw_info: FrameworkInfo,
                                     tpc: TargetPlatformCapabilities,
-                                    mixed_precision_enable: bool = False
+                                    mixed_precision_enable: bool = False,
+                                    manual_bit_width_override: Optional[int] = None):
     """
     Create and set quantization configurations to a node (for both weights and activation).
 
     Args:
-        node: Node to set its quantization configurations.
-        graph: Model's internal representation graph.
-        quant_config: Quantization configuration to generate the node's configurations from.
-        fw_info: Information needed for quantization about the specific framework.
-        tpc: TargetPlatformCapabilities to get default OpQuantizationConfig.
-        mixed_precision_enable:
+        node (BaseNode): Node to set its quantization configurations.
+        graph (Graph): Model's internal representation graph.
+        quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
+        tpc (TargetPlatformCapabilities): TargetPlatformCapabilities to get default OpQuantizationConfig.
+        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
+        manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
     """
    node_qc_options = node.get_qco(tpc)
    base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options)
 
+    # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
+    # and update base_config accordingly.
+    base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
+        node=node,
+        node_qc_options_list=node_qc_options_list,
+        base_config=base_config,
+        manual_bit_width_override=manual_bit_width_override,
+        mixed_precision_enable=mixed_precision_enable)
+
    # Create QC candidates for weights and activation combined
    weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
    node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
@@ -199,16 +217,16 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
     Create a list of candidates of weights and activation quantization configurations for a node.
 
     Args:
-        qc: Quantization configuration the quantization process should follow.
-        fw_info: Framework information (e.g., which layers should have their kernels
-        weight_channel_axis: (Output, Input) channel index of the node's kernel.
-        node_qc_options_list: List of quantization configs of node.
-        base_config: Base quantization config for node.
-        node: A node to set quantization configuration candidates to.
-        mixed_precision_enable:
+        qc (QuantizationConfig): Quantization configuration the quantization process should follow.
+        fw_info (FrameworkInfo): Framework information (e.g., which layers should have their kernels quantized).
+        weight_channel_axis (Tuple[int, int]): (Output, Input) channel index of the node's kernel.
+        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
+        base_config (OpQuantizationConfig): Base quantization config for node.
+        node (BaseNode): A node to set quantization configuration candidates to.
+        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
 
     Returns:
-        List of candidates of weights quantization configurations to set for a node.
+        List[CandidateNodeQuantizationConfig]: List of candidates of weights quantization configurations to set for a node.
     """
 
    candidates = []
@@ -231,3 +249,51 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
                                node_attrs_list))
 
    return candidates
+
+
+def filter_qc_options_with_manual_bit_width(
+        node: BaseNode,
+        node_qc_options_list: List[OpQuantizationConfig],
+        base_config: OpQuantizationConfig,
+        manual_bit_width_override: Optional[int],
+        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+    """
+    Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
+
+    Args:
+        node (BaseNode): A node to set quantization configuration candidates to.
+        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
+        base_config (OpQuantizationConfig): Base quantization config for the node.
+        manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width.
+        mixed_precision_enable (bool): Whether mixed precision is enabled.
+
+    Returns:
+        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
+    """
+    if manual_bit_width_override is None:
+        return base_config, node_qc_options_list
+
+    # Filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override.
+    node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
+                            manual_bit_width_override == op_cfg.activation_n_bits]
+
+    if len(node_qc_options_list) == 0:
+        Logger.critical(f"Manually selected activation bit-width {manual_bit_width_override} is invalid for node {node}.")
+    else:
+        # Update the base_config to one of the values from the filtered node_qc_options_list.
+        # First, check if a configuration similar to the original base_config but with activation bits equal to manual_bit_width_override exists.
+        # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
+        Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {manual_bit_width_override} bits.")
+        updated_base_config = base_config.clone_and_edit({ACTIVATION_N_BITS, manual_bit_width_override})
+        if updated_base_config in node_qc_options_list:
+            # If a base_config with the specified manual_bit_width_override exists in the node_qc_options_list,
+            # point the base_config to this option.
+            base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
+        else:
+            # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
+            base_config = node_qc_options_list[0]
+            if len(node_qc_options_list) > 0 and not mixed_precision_enable:
+                Logger.info(
+                    f"Request received to select {manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
+                    f" Overriding base_config with an option that uses {manual_bit_width_override} bit activations.")  # pragma: no cover
+    return base_config, node_qc_options_list
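To make the override semantics concrete, here is a hedged, self-contained restatement of the filtering step with a plain dataclass standing in for OpQuantizationConfig (every name below is illustrative, not MCT API):

from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass(frozen=True)
class FakeOpConfig:  # stand-in for OpQuantizationConfig
    activation_n_bits: int
    weights_n_bits: int

def filter_by_manual_bits(options: List[FakeOpConfig],
                          base: FakeOpConfig,
                          override: Optional[int]) -> Tuple[FakeOpConfig, List[FakeOpConfig]]:
    if override is None:
        return base, options  # no manual selection: leave everything untouched
    kept = [o for o in options if o.activation_n_bits == override]
    if not kept:
        raise ValueError(f"no candidate offers {override} activation bits")
    # Prefer the option that matches the original base config apart from the
    # activation bits; otherwise fall back to the first surviving option.
    preferred = FakeOpConfig(override, base.weights_n_bits)
    return (preferred if preferred in kept else kept[0]), kept

options = [FakeOpConfig(a, w) for a in (4, 8, 16) for w in (4, 8)]
base, kept = filter_by_manual_bits(options, FakeOpConfig(16, 8), override=4)
assert base == FakeOpConfig(4, 8) and all(o.activation_n_bits == 4 for o in kept)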
model_compression_toolkit/core/graph_prep_runner.py
CHANGED
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
 from model_compression_toolkit.core.common.graph.base_graph import Graph
+from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
 from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
@@ -38,6 +39,7 @@ def graph_preparation_runner(in_model: Any,
                             fw_info: FrameworkInfo,
                             fw_impl: FrameworkImplementation,
                             tpc: TargetPlatformCapabilities,
+                            bit_width_config: BitWidthConfig = None,
                             tb_w: TensorboardWriter = None,
                             mixed_precision_enable: bool = False,
                             running_gptq: bool = False) -> Graph:
@@ -50,17 +52,18 @@ def graph_preparation_runner(in_model: Any,
     - Apply all necessary substitutions to finalize the graph for quantization.
 
     Args:
-        in_model: Model to quantize.
-        representative_data_gen: Dataset used for calibration.
-        quantization_config: QuantizationConfig containing parameters of how the model should be quantized.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+        in_model (Any): Model to quantize.
+        representative_data_gen (Callable): Dataset used for calibration.
+        quantization_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.).
-        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-        tpc: TargetPlatformCapabilities object that models the inference target platform and
+        fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
+        tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that models the inference target platform and
            the attached framework operator's information.
-
-
-
+        bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
+        tb_w (TensorboardWriter): TensorboardWriter object for logging.
+        mixed_precision_enable (bool): is mixed precision enabled.
+        running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process.
 
     Returns:
        An internal graph representation of the input model.
@@ -78,6 +81,7 @@ def graph_preparation_runner(in_model: Any,
    transformed_graph = get_finalized_graph(graph,
                                            tpc,
                                            quantization_config,
+                                           bit_width_config,
                                            fw_info,
                                            tb_w,
                                            fw_impl,
@@ -90,6 +94,7 @@ def graph_preparation_runner(in_model: Any,
 def get_finalized_graph(initial_graph: Graph,
                        tpc: TargetPlatformCapabilities,
                        quant_config: QuantizationConfig = DEFAULTCONFIG,
+                       bit_width_config: BitWidthConfig = None,
                        fw_info: FrameworkInfo = None,
                        tb_w: TensorboardWriter = None,
                        fw_impl: FrameworkImplementation = None,
@@ -104,6 +109,7 @@ def get_finalized_graph(initial_graph: Graph,
        tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
        quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
            quantized.
+        bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
            kernel channels indices, groups of layers by how they should be quantized, etc.)
        tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
@@ -147,6 +153,7 @@ def get_finalized_graph(initial_graph: Graph,
    ######################################
    transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
                                                                quant_config=quant_config,
+                                                               bit_width_config=bit_width_config,
                                                                mixed_precision_enable=mixed_precision_enable,
                                                                running_gptq=running_gptq)
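The hunks in this file are pass-through plumbing: graph_preparation_runner hands bit_width_config to get_finalized_graph, which hands it to set_quantization_configuration_to_graph. A toy sketch of that shape (illustrative names and types, not the MCT signatures):

from typing import Dict, List, Optional, Tuple

# Each pipeline stage accepts the manual overrides and forwards them one level
# down, mirroring the bit_width_config threading added in this diff.
def set_node_configs(graph: List[str],
                     overrides: Optional[Dict[str, int]]) -> List[Tuple[str, Optional[int]]]:
    overrides = overrides or {}
    return [(node, overrides.get(node)) for node in graph]

def finalize_graph(graph: List[str],
                   overrides: Optional[Dict[str, int]] = None) -> List[Tuple[str, Optional[int]]]:
    return set_node_configs(graph, overrides)

def prepare_graph(graph: List[str],
                  overrides: Optional[Dict[str, int]] = None) -> List[Tuple[str, Optional[int]]]:
    return finalize_graph(graph, overrides=overrides)

print(prepare_graph(["conv1", "relu1"], {"conv1": 4}))  # [('conv1', 4), ('relu1', None)]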
model_compression_toolkit/pruning/keras/pruning_facade.py
CHANGED
@@ -21,6 +21,7 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
+from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
model_compression_toolkit/pruning/pytorch/pruning_facade.py
CHANGED
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
+from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
model_compression_toolkit/xquant/common/model_folding_utils.py
CHANGED
@@ -16,6 +16,7 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
+from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 
 from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
{mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/LICENSE.md
RENAMED
File without changes
{mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240812.432.dist-info}/top_level.txt
RENAMED
File without changes