mct_nightly-2.1.0.20240810.432-py3-none-any.whl → mct_nightly-2.1.0.20240812.432-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- mct_nightly-2.1.0.20240810.432.dist-info/METADATA
+++ mct_nightly-2.1.0.20240812.432.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.1.0.20240810.432
+ Version: 2.1.0.20240812.432
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
--- mct_nightly-2.1.0.20240810.432.dist-info/RECORD
+++ mct_nightly-2.1.0.20240812.432.dist-info/RECORD
@@ -1,13 +1,13 @@
- model_compression_toolkit/__init__.py,sha256=7RNr7Z_TEFTRvi88U4kbS4kMM1HNh0fF32c0pfQidmA,1573
+ model_compression_toolkit/__init__.py,sha256=dFfNYHNevEMx3n6CPcaXHcIQxE4Nlhrsckn2CtIDLiY,1573
  model_compression_toolkit/constants.py,sha256=0qrEGjX36Oo7Lt8mR0LD2aSe2xA7gKrhkzBGp7g5eiA,4345
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
  model_compression_toolkit/metadata.py,sha256=UtXS5ClK-qPoxGRuytlDGZSzgLo911dMni2EFRcg6io,3623
- model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
+ model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-tXYeGxTo5YnNUM,2077
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
- model_compression_toolkit/core/graph_prep_runner.py,sha256=kM70wmNG3yMFiGQc0uO0wn9j4ZbSWxUEykpxDK55doc,10567
+ model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
- model_compression_toolkit/core/runner.py,sha256=uXpyYaX1uFNhKituGmSfKb3ZkguXG2V_Cg6XCnprplg,13569
+ model_compression_toolkit/core/runner.py,sha256=XQDNJirZkVJ_FXP72d7tbVc_Tr3Jw0Eqm_kxNHW8kPs,13636
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
  model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
@@ -70,7 +70,7 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=KifDMbm7qkSfvSl6pcZzQ82naIXzeKL6aT-VsvWZYyc,7901
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=HILF7CIn-GYPvPmTFyvjWLhuLDwSGwdBcAaKFgVYrwk,4745
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=az0XfBPVm1kAfxNCPb0Z-Q05-F-vqnmyRpKm6SBLa6c,13826
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=3ZOI-RNp5faT-U2Og7rLW9EKwBB6ooa7-RwSsWJmquo,14022
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py,sha256=ttc8wPa_9LZansutQ2f1ss-RTzgTv739wy3qsdLzyyk,4217
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py,sha256=QhuqaECEGLnYC08iD6-2XXcU7NXbPzYf1sQcjYlGak8,1682
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py,sha256=WC1EHoNuo_lrzy4NRhGJ1cgmJ2IsFsbmP86mrVO3AVA,21506
@@ -98,8 +98,9 @@ model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB
  model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=APY8BsM9B7ZxVCH6n1xs9fSCTB_A9ou9gHrCQl1DOdI,5131
  model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=4ohJrJHNzZk5uMnZEYkwLx2TDGzkh5kRhLGNVYNC6dc,5978
  model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
+ model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=r6VQXgyJxX_AM1JTzv-sTcrvCTnktBfOkVP20RllNmk,4586
  model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=yU-Cr6S4wOSkDk57iH2NVe-WII0whOhLryejkomCOt4,4940
- model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
+ model_compression_toolkit/core/common/quantization/core_config.py,sha256=f0uSuY9mX-vLX_1s2DemPARQlAXmLPKJKPtCArz3pZI,2670
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=8G8SpE_4rb8xBp8d6mMq8R_OnXJ_1oxB2g-Lxk9EJCM,1691
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
  model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=YycYN8_JMzvSR3pTVm5dT5x4zP3yBHn0Z9agnwrvOKI,26395
@@ -108,7 +109,7 @@ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=9Y4eVDWCXFvCaXy2gbb-1880sp52M8wqH0M3KgAw8rM,12834
+ model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=0pZVO4wsNP815R9ZOd5ojC_OdNEeKkxYKdjggsqsZKg,17750
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
  model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=Fd_gxr5js-mqEwucaRR1CQAZ1W_wna19L1gAPeOzxRQ,23610
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
@@ -380,9 +381,9 @@ model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=6uxq_w62jn8DDOt9T7VtA6jZ8jTAPcbTufKFOYpVUm4,8768
  model_compression_toolkit/pruning/__init__.py,sha256=lQMZS8G0pvR1LVi53nnJHNXgLNTan_MWMdwsVxhjrow,1106
  model_compression_toolkit/pruning/keras/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
- model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=Vt9ipysniwQw4erWhwMO4oMCpIFUMKIGq67ugieMZd8,8612
+ model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=vDpY97xirGF-o5XB6HvG_y2bL4LzfiTW3cPURTvaeKI,8707
  model_compression_toolkit/pruning/pytorch/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
- model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=cSuvHHCqgr7k9FdYOxFqe2njLcJ7IkzCrWSb26S0TK8,9398
+ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=1uo5jWgbFNNhRbfb8da5REymMUdLJ3JidR8aAMXCBoE,9493
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -507,7 +508,7 @@ model_compression_toolkit/xquant/common/core_report_generator.py,sha256=GHnJJpK6
  model_compression_toolkit/xquant/common/dataset_utils.py,sha256=91uXF9UwxdY7BvUT0FNkFm8a69c8oK8Xdl-y7lbuJxk,1649
  model_compression_toolkit/xquant/common/framework_report_utils.py,sha256=YE49232ESflW6ZaUABF1pk_GGHBxa_F1X5oRN2Jogys,3734
  model_compression_toolkit/xquant/common/model_analyzer.py,sha256=T_8OetIQNqR0nkfSatWsEceXSPYpHfYjboBPIyR03-w,3953
- model_compression_toolkit/xquant/common/model_folding_utils.py,sha256=y5Vmc-hJ2rJhzWdM53HdY-PrT5LlspejTUNlXaCrq9Q,4720
+ model_compression_toolkit/xquant/common/model_folding_utils.py,sha256=7XMNmsngJgCPVjsuMNt6g4hzhkviB45qUmNRe9jQE7g,4815
  model_compression_toolkit/xquant/common/similarity_calculator.py,sha256=yCs_vlOThLzq7z-u2PkcEErLj7N7qCBPpRa6_5h34J8,10460
  model_compression_toolkit/xquant/common/similarity_functions.py,sha256=Atah1otdX9oUUch2JK-p-e291QHtkP_c4DfLG9WWo1Y,2935
  model_compression_toolkit/xquant/common/tensorboard_utils.py,sha256=85ABGQGKPZzctyZCHLazK0GxZ2ZUtQA3hZ_9fPiuMs0,6533
@@ -526,8 +527,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=yjghWXxqOtT-QXoXBOuJyh45yUpFI0pKjdDegum2i68,9705
- mct_nightly-2.1.0.20240810.432.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.1.0.20240810.432.dist-info/METADATA,sha256=jzPgK6fMH8_ZGTf9pcXXXPwGUondUO7o9yuRwpCTAm4,19718
- mct_nightly-2.1.0.20240810.432.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- mct_nightly-2.1.0.20240810.432.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.1.0.20240810.432.dist-info/RECORD,,
+ mct_nightly-2.1.0.20240812.432.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.1.0.20240812.432.dist-info/METADATA,sha256=2qqUeeA_e60PJ8S7cITQK2UlEOxMvWKKncdTK-H3v9E,19718
+ mct_nightly-2.1.0.20240812.432.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ mct_nightly-2.1.0.20240812.432.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.1.0.20240812.432.dist-info/RECORD,,
--- model_compression_toolkit/__init__.py
+++ model_compression_toolkit/__init__.py
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
- __version__ = "2.1.0.20240810.000432"
+ __version__ = "2.1.0.20240812.000432"
--- model_compression_toolkit/core/__init__.py
+++ model_compression_toolkit/core/__init__.py
@@ -19,6 +19,7 @@ from model_compression_toolkit.core.common.quantization.debug_config import Debu
  from model_compression_toolkit.core.common.quantization import quantization_config
  from model_compression_toolkit.core.common.mixed_precision import mixed_precision_quantization_config
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
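
With this export, the new class becomes reachable from the public core namespace rather than only via its deep internal path. A one-line sketch (illustrative; it simply exercises the import added above):

    from model_compression_toolkit.core import BitWidthConfig

    bit_width_config = BitWidthConfig()  # empty config; selections are added via set_manual_activation_bit_width
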
--- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
+++ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
@@ -67,6 +67,7 @@ def compute_resource_utilization_data(in_model: Any,
  fw_info,
  fw_impl,
  tpc,
+ bit_width_config=core_config.bit_width_config,
  mixed_precision_enable=mixed_precision_enable)
 
  # Compute parameters sum
@@ -227,6 +228,7 @@ def requires_mixed_precision(in_model: Any,
  fw_info,
  fw_impl,
  tpc,
+ bit_width_config=core_config.bit_width_config,
  mixed_precision_enable=False)
  # Compute max weights memory in bytes
  weights_memory_by_layer_bytes, _ = compute_nodes_weights_params(transformed_graph, fw_info)
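
Both call sites now forward core_config.bit_width_config into graph preparation, so resource-utilization estimates and the requires-mixed-precision decision are computed over the same manually constrained candidate set that the actual quantization flow will later use.
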
--- /dev/null
+++ model_compression_toolkit/core/common/quantization/bit_width_config.py
@@ -0,0 +1,91 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import List, Union, Dict
+
+ from model_compression_toolkit.core.common import Graph
+ from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
+ from model_compression_toolkit.logger import Logger
+
+
+ class ManualBitWidthSelection:
+     """
+     Class to encapsulate the manual bit-width selection configuration for a specific filter.
+
+     Attributes:
+         filter (BaseNodeMatcher): The filter used to select nodes for bit-width manipulation.
+         bit_width (int): The bit width to be applied to the selected nodes.
+     """
+     def __init__(self,
+                  filter: BaseNodeMatcher,
+                  bit_width: int):
+         self.filter = filter
+         self.bit_width = bit_width
+
+
+ class BitWidthConfig:
+     """
+     Class to manage manual bit-width configurations.
+
+     Attributes:
+         manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations.
+     """
+     def __init__(self,
+                  manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = None):
+         self.manual_activation_bit_width_selection_list = [] if manual_activation_bit_width_selection_list is None else manual_activation_bit_width_selection_list
+
+     def __repr__(self):
+         # Used for debugging, thus no cover.
+         return str(self.__dict__)  # pragma: no cover
+
+     def set_manual_activation_bit_width(self,
+                                         filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
+                                         bit_widths: Union[List[int], int]):
+         """
+         Add a manual bit-width selection to the configuration.
+
+         Args:
+             filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
+             bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
+                 If a single value is given, it is applied to all the filters.
+         """
+         filters = [filters] if not isinstance(filters, list) else filters
+         bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
+         if len(bit_widths) > 1 and len(bit_widths) != len(filters):
+             Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
+                             f"must match the number of filters {len(filters)}, or a single bit_width value "
+                             f"should be provided for all filters.")
+         elif len(bit_widths) == 1 and len(filters) > 1:
+             bit_widths = [bit_widths[0] for _ in filters]
+         for bit_width, filter in zip(bit_widths, filters):
+             self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]
+
+     def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
+         """
+         Retrieve nodes from the graph whose bit-widths should be changed according to the manual bit-width selections.
+
+         Args:
+             graph (Graph): The graph containing the nodes to be filtered and manipulated.
+
+         Returns:
+             Dict: A dictionary mapping nodes to their new bit-widths.
+         """
+         nodes_to_change_bit_width = {}
+         for manual_bit_width_selection in self.manual_activation_bit_width_selection_list:
+             filtered_nodes = graph.filter(manual_bit_width_selection.filter)
+             if len(filtered_nodes) == 0:
+                 Logger.critical(f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
+                                 f"to change their bit width to {manual_bit_width_selection.bit_width}.")
+             nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width for n in filtered_nodes})
+         return nodes_to_change_bit_width
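
A short usage sketch of the new API. The NodeTypeFilter matcher and the Keras layer types below are assumptions for illustration (NodeTypeFilter lives in model_compression_toolkit.core.common.network_editors.node_filters in current MCT releases); only BitWidthConfig itself comes from this diff:

    from tensorflow.keras.layers import Add, Conv2D  # hypothetical target layer types

    from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter
    from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig

    bwc = BitWidthConfig()

    # One filter, one bit-width: quantize every Add node's activation to 4 bits.
    bwc.set_manual_activation_bit_width(NodeTypeFilter(Add), 4)

    # Several filters sharing a single bit-width: the scalar is broadcast to all filters.
    bwc.set_manual_activation_bit_width([NodeTypeFilter(Conv2D), NodeTypeFilter(Add)], 8)

During graph preparation, get_nodes_to_manipulate_bit_widths(graph) resolves these filters to concrete nodes and calls Logger.critical when a filter matches nothing, so a mistaken layer selection fails fast instead of being silently ignored.
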
--- model_compression_toolkit/core/common/quantization/core_config.py
+++ model_compression_toolkit/core/common/quantization/core_config.py
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
  from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
@@ -22,9 +23,10 @@ class CoreConfig:
      A class to hold the configurations classes of the MCT-core.
      """
      def __init__(self,
-                  quantization_config: QuantizationConfig = QuantizationConfig(),
+                  quantization_config: QuantizationConfig = None,
                   mixed_precision_config: MixedPrecisionQuantizationConfig = None,
-                  debug_config: DebugConfig = DebugConfig()
+                  bit_width_config: BitWidthConfig = None,
+                  debug_config: DebugConfig = None
                   ):
          """
 
@@ -32,10 +34,12 @@ class CoreConfig:
              quantization_config (QuantizationConfig): Config for quantization.
              mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization.
                  If None, a default MixedPrecisionQuantizationConfig is used.
+             bit_width_config (BitWidthConfig): Config for manual bit-width selection.
              debug_config (DebugConfig): Config for debugging and editing the network quantization process.
          """
-         self.quantization_config = quantization_config
-         self.debug_config = debug_config
+         self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
+         self.bit_width_config = BitWidthConfig() if bit_width_config is None else bit_width_config
+         self.debug_config = DebugConfig() if debug_config is None else debug_config
 
          if mixed_precision_config is None:
              self.mixed_precision_config = MixedPrecisionQuantizationConfig()
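
Note the design change in the constructor: the mutable QuantizationConfig()/DebugConfig() defaults move out of the signature, where a single instance would be shared by every call, and become None, with a fresh instance created per CoreConfig. A wiring sketch, assuming MCT's public Keras PTQ facade and a model/representative_dataset defined elsewhere:

    import model_compression_toolkit as mct

    core_cfg = mct.core.CoreConfig(bit_width_config=bwc)  # bwc as in the sketch above

    quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
        model, representative_dataset, core_config=core_cfg)
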
--- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
+++ model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
@@ -15,9 +15,11 @@
 
 
  import copy
- from typing import List, Tuple
+ from typing import List, Tuple, Optional
 
+ from mct_quantizers.common.constants import ACTIVATION_N_BITS
  from model_compression_toolkit.core.common import BaseNode
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -37,19 +39,21 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.op_q
 
  def set_quantization_configuration_to_graph(graph: Graph,
                                              quant_config: QuantizationConfig,
+                                             bit_width_config: BitWidthConfig = None,
                                              mixed_precision_enable: bool = False,
                                              running_gptq: bool = False) -> Graph:
      """
      Add quantization configuration for each graph node.
 
      Args:
-         graph: Graph for which to add quantization info to each node.
-         quant_config: Quantization configuration containing parameters for how the graph should be quantized.
-         mixed_precision_enable: is mixed precision enabled.
-         running_gptq: Whether or not a GPTQ optimization is planned to run after the PTQ process.
+         graph (Graph): Graph for which to add quantization info to each node.
+         quant_config (QuantizationConfig): Quantization configuration containing parameters for how the graph should be quantized.
+         bit_width_config (BitWidthConfig): Configuration for manual bit width selection. Defaults to None.
+         mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
+         running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process. Defaults to False.
 
      Returns:
-         The graph with quantization configurations attached to each node in it.
+         Graph: The graph with quantization configurations attached to each node in it.
      """
 
      if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
@@ -62,13 +66,16 @@ def set_quantization_configuration_to_graph(graph: Graph,
          Logger.warning("Using the HMSE error method for weights quantization parameters search. "
                         "Note: This method may significantly increase runtime during the parameter search process.")
 
+     nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)
+
      for n in graph.nodes:
          set_quantization_configs_to_node(node=n,
                                           graph=graph,
                                           quant_config=quant_config,
                                           fw_info=graph.fw_info,
                                           tpc=graph.tpc,
-                                          mixed_precision_enable=mixed_precision_enable)
+                                          mixed_precision_enable=mixed_precision_enable,
+                                          manual_bit_width_override=nodes_to_manipulate_bit_widths.get(n))
      return graph
 
 
@@ -77,21 +84,32 @@ def set_quantization_configs_to_node(node: BaseNode,
                                      quant_config: QuantizationConfig,
                                      fw_info: FrameworkInfo,
                                      tpc: TargetPlatformCapabilities,
-                                     mixed_precision_enable: bool = False):
+                                     mixed_precision_enable: bool = False,
+                                     manual_bit_width_override: Optional[int] = None):
      """
      Create and set quantization configurations to a node (for both weights and activation).
 
      Args:
-         node: Node to set its quantization configurations.
-         graph: Model's internal representation graph.
-         quant_config: Quantization configuration to generate the node's configurations from.
-         fw_info: Information needed for quantization about the specific framework.
-         tpc: TargetPlatformCapabilities to get default OpQuantizationConfig.
-         mixed_precision_enable: is mixed precision enabled.
+         node (BaseNode): Node to set its quantization configurations.
+         graph (Graph): Model's internal representation graph.
+         quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
+         fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
+         tpc (TargetPlatformCapabilities): TargetPlatformCapabilities to get default OpQuantizationConfig.
+         mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
+         manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
      """
      node_qc_options = node.get_qco(tpc)
      base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options)
 
+     # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options
+     # with activation bits equal to manual_bit_width_override, and update base_config accordingly.
+     base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
+         node=node,
+         node_qc_options_list=node_qc_options_list,
+         base_config=base_config,
+         manual_bit_width_override=manual_bit_width_override,
+         mixed_precision_enable=mixed_precision_enable)
+
      # Create QC candidates for weights and activation combined
      weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
      node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
@@ -199,16 +217,16 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
      Create a list of candidates of weights and activation quantization configurations for a node.
 
      Args:
-         qc: Quantization configuration the quantization process should follow.
-         fw_info: Framework information (e.g., which layers should have their kernels' quantized).
-         weight_channel_axis: (Output, Input) channel index of the node's kernel.
-         node_qc_options_list: List of quantization configs of node.
-         base_config: Base quantization config for node.
-         node: A node to set quantization configuration candidates to.
-         mixed_precision_enable: is mixed precision enabled
+         qc (QuantizationConfig): Quantization configuration the quantization process should follow.
+         fw_info (FrameworkInfo): Framework information (e.g., which layers should have their kernels quantized).
+         weight_channel_axis (Tuple[int, int]): (Output, Input) channel index of the node's kernel.
+         node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
+         base_config (OpQuantizationConfig): Base quantization config for node.
+         node (BaseNode): A node to set quantization configuration candidates to.
+         mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
 
      Returns:
-         List of candidates of weights quantization configurations to set for a node.
+         List[CandidateNodeQuantizationConfig]: List of candidates of weights quantization configurations to set for a node.
      """
 
      candidates = []
@@ -231,3 +249,51 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
                                            node_attrs_list))
 
      return candidates
+
+
+ def filter_qc_options_with_manual_bit_width(
+         node: BaseNode,
+         node_qc_options_list: List[OpQuantizationConfig],
+         base_config: OpQuantizationConfig,
+         manual_bit_width_override: Optional[int],
+         mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+     """
+     Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
+
+     Args:
+         node (BaseNode): A node to set quantization configuration candidates to.
+         node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
+         base_config (OpQuantizationConfig): Base quantization config for the node.
+         manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width.
+         mixed_precision_enable (bool): Whether mixed precision is enabled.
+
+     Returns:
+         Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
+     """
+     if manual_bit_width_override is None:
+         return base_config, node_qc_options_list
+
+     # Filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override.
+     node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
+                             manual_bit_width_override == op_cfg.activation_n_bits]
+
+     if len(node_qc_options_list) == 0:
+         Logger.critical(f"Manually selected activation bit-width {manual_bit_width_override} is invalid for node {node}.")
+     else:
+         # Update the base_config to one of the values from the filtered node_qc_options_list.
+         # First, check if a configuration similar to the original base_config but with activation bits equal to
+         # manual_bit_width_override exists. If it does, use it as the base_config. If not, choose a different
+         # configuration from node_qc_options_list.
+         Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {manual_bit_width_override} bits.")
+         updated_base_config = base_config.clone_and_edit(**{ACTIVATION_N_BITS: manual_bit_width_override})
+         if updated_base_config in node_qc_options_list:
+             # If a base_config with the specified manual_bit_width_override exists in the node_qc_options_list,
+             # point the base_config to this option.
+             base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
+         else:
+             # Otherwise, fall back to the first configuration in node_qc_options_list and log the override.
+             base_config = node_qc_options_list[0]
+             if not mixed_precision_enable:
+                 Logger.info(
+                     f"Request received to select {manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
+                     f" Overriding base_config with an option that uses {manual_bit_width_override} bit activations.")  # pragma: no cover
+     return base_config, node_qc_options_list
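
For intuition, here is a self-contained toy (deliberately not MCT code) that mirrors the selection logic above: keep only the options whose activation bit-width equals the override, and prefer the option equivalent to the original base config whenever it survives the filter:

    from dataclasses import dataclass, replace
    from typing import List, Optional, Tuple

    @dataclass(frozen=True)
    class ToyOpConfig:  # stand-in for OpQuantizationConfig
        activation_n_bits: int
        weights_n_bits: int

    def toy_filter(options: List[ToyOpConfig],
                   base: ToyOpConfig,
                   override: Optional[int]) -> Tuple[ToyOpConfig, List[ToyOpConfig]]:
        if override is None:
            return base, options
        # Keep only options matching the requested activation bit-width.
        filtered = [o for o in options if o.activation_n_bits == override]
        if not filtered:
            raise ValueError(f"No option offers {override} activation bits")
        # Analogue of clone_and_edit: the base config with only the activation bits changed.
        wanted = replace(base, activation_n_bits=override)
        return (wanted if wanted in filtered else filtered[0]), filtered

    opts = [ToyOpConfig(8, 8), ToyOpConfig(4, 8), ToyOpConfig(2, 8)]
    base, opts = toy_filter(opts, base=ToyOpConfig(8, 8), override=4)
    assert base == ToyOpConfig(4, 8) and opts == [ToyOpConfig(4, 8)]
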
--- model_compression_toolkit/core/graph_prep_runner.py
+++ model_compression_toolkit/core/graph_prep_runner.py
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common import FrameworkInfo
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
  from model_compression_toolkit.core.common.graph.base_graph import Graph
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
  from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
@@ -38,6 +39,7 @@ def graph_preparation_runner(in_model: Any,
                               fw_info: FrameworkInfo,
                               fw_impl: FrameworkImplementation,
                               tpc: TargetPlatformCapabilities,
+                              bit_width_config: BitWidthConfig = None,
                               tb_w: TensorboardWriter = None,
                               mixed_precision_enable: bool = False,
                               running_gptq: bool = False) -> Graph:
@@ -50,17 +52,18 @@ def graph_preparation_runner(in_model: Any,
      - Apply all necessary substitutions to finalize the graph for quantization.
 
      Args:
-         in_model: Model to quantize.
-         representative_data_gen: Dataset used for calibration.
-         quantization_config: QuantizationConfig containing parameters of how the model should be quantized.
-         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+         in_model (Any): Model to quantize.
+         representative_data_gen (Callable): Dataset used for calibration.
+         quantization_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
+         fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
              groups of layers by how they should be quantized, etc.).
-         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-         tpc: TargetPlatformCapabilities object that models the inference target platform and
+         fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
+         tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that models the inference target platform and
              the attached framework operator's information.
-         tb_w: TensorboardWriter object for logging.
-         mixed_precision_enable: is mixed precision enabled.
-         running_gptq: Whether or not a GPTQ optimization is planned to run after the PTQ process.
+         bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
+         tb_w (TensorboardWriter): TensorboardWriter object for logging.
+         mixed_precision_enable (bool): Whether mixed precision is enabled.
+         running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process.
 
      Returns:
          An internal graph representation of the input model.
@@ -78,6 +81,7 @@ def graph_preparation_runner(in_model: Any,
      transformed_graph = get_finalized_graph(graph,
                                              tpc,
                                              quantization_config,
+                                             bit_width_config,
                                              fw_info,
                                              tb_w,
                                              fw_impl,
@@ -90,6 +94,7 @@ def graph_preparation_runner(in_model: Any,
  def get_finalized_graph(initial_graph: Graph,
                          tpc: TargetPlatformCapabilities,
                          quant_config: QuantizationConfig = DEFAULTCONFIG,
+                         bit_width_config: BitWidthConfig = None,
                          fw_info: FrameworkInfo = None,
                          tb_w: TensorboardWriter = None,
                          fw_impl: FrameworkImplementation = None,
@@ -104,6 +109,7 @@ def get_finalized_graph(initial_graph: Graph,
          tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
          quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
              quantized.
+         bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
          fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
              kernel channels indices, groups of layers by how they should be quantized, etc.)
          tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
@@ -147,6 +153,7 @@ def get_finalized_graph(initial_graph: Graph,
      ######################################
      transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
                                                                  quant_config=quant_config,
+                                                                 bit_width_config=bit_width_config,
                                                                  mixed_precision_enable=mixed_precision_enable,
                                                                  running_gptq=running_gptq)
 
--- model_compression_toolkit/core/runner.py
+++ model_compression_toolkit/core/runner.py
@@ -115,6 +115,7 @@ def core_runner(in_model: Any,
  fw_info,
  fw_impl,
  tpc,
+ core_config.bit_width_config,
  tb_w,
  mixed_precision_enable=core_config.mixed_precision_enable,
  running_gptq=running_gptq)
--- model_compression_toolkit/pruning/keras/pruning_facade.py
+++ model_compression_toolkit/pruning/keras/pruning_facade.py
@@ -21,6 +21,7 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
  from model_compression_toolkit.logger import Logger
--- model_compression_toolkit/pruning/pytorch/pruning_facade.py
+++ model_compression_toolkit/pruning/pytorch/pruning_facade.py
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
  from model_compression_toolkit.logger import Logger
--- model_compression_toolkit/xquant/common/model_folding_utils.py
+++ model_compression_toolkit/xquant/common/model_folding_utils.py
@@ -16,6 +16,7 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 
  from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner