mct-nightly 2.3.0.20250416.541__py3-none-any.whl → 2.3.0.20250418.531__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/METADATA +1 -1
- {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/RECORD +19 -19
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_info.py +6 -0
- model_compression_toolkit/core/common/graph/base_graph.py +9 -19
- model_compression_toolkit/core/common/graph/base_node.py +25 -39
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +5 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +7 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +82 -100
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +32 -41
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +13 -11
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +12 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +2 -10
- model_compression_toolkit/core/keras/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +2 -9
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.
|
3
|
+
Version: 2.3.0.20250418.531
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: Apache Software License
|
{mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250418.531.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=kz46wlIXHqUJ124-nGslxvPJ-ClTRO6XVJAKyFnXNrk,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
@@ -13,7 +13,7 @@ model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utyb
|
|
13
13
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
14
14
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
15
15
|
model_compression_toolkit/core/common/framework_implementation.py,sha256=L88uv_sfYM_56FSmxXP--emjv01_lk7IPqOI7QBZEt0,22939
|
16
|
-
model_compression_toolkit/core/common/framework_info.py,sha256=
|
16
|
+
model_compression_toolkit/core/common/framework_info.py,sha256=5tderHT-7Cd21QrRFIJj3hH_gAcnlivOzwZ5m1ldJOs,6526
|
17
17
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
18
18
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
19
19
|
model_compression_toolkit/core/common/model_collector.py,sha256=Tno3-qx9jmPZAZyLYgbPlMLHakVfuEH5deuToZNuCb0,13195
|
@@ -34,8 +34,8 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
|
|
34
34
|
model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=W8qZejLwbm-lkvNF3GepNL3ypO10vFRxOxbq-o_rt_I,15479
|
35
35
|
model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
|
36
36
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
37
|
-
model_compression_toolkit/core/common/graph/base_graph.py,sha256=
|
38
|
-
model_compression_toolkit/core/common/graph/base_node.py,sha256=
|
37
|
+
model_compression_toolkit/core/common/graph/base_graph.py,sha256=2aRpL8OP-JWKc2XFdsAQjACthJZmS8zgwIX-wjBRCFQ,41383
|
38
|
+
model_compression_toolkit/core/common/graph/base_node.py,sha256=AbUadAT581zelVcGcK9_--6CAGiht9qwkeWahwT3RzE,33389
|
39
39
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
40
40
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
|
41
41
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
@@ -67,18 +67,18 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_uti
|
|
67
67
|
model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
|
68
68
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
|
70
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256
|
71
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=
|
72
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=
|
70
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=-hOMBucYn12ePyLd0b1KxniPOIRu4b53SwEzv0bWToI,4943
|
71
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=d5-3j2e_rdcQOT7c4s0p7640i3nSetjJ6MgMhhMM7dc,6152
|
72
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=a0lyySRmQ1vKikx5YvDMA4l1Eha-W5BCPYScvDlL_6c,37300
|
73
73
|
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4bkM8pYKvk18cxHbx973Dz6qWrNT0MRm44cuk__qVaI,27297
|
74
74
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
75
|
-
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=
|
75
|
+
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
|
76
76
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
77
77
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
|
78
78
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=cjFnpDvxZDE4K2sgt26DhosA2XqhxHDs0eW5Qe7AwAQ,40668
|
79
79
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
|
80
80
|
model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
81
|
-
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=
|
81
|
+
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=32s620FyREMBJYx3AUp6umlRfHxjqhL31PRbVtLdMJ4,6664
|
82
82
|
model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
|
83
83
|
model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
|
84
84
|
model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
|
@@ -106,7 +106,7 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
106
106
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
107
107
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
|
108
108
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=n2A8pO7_DMMae4o69U0I00iW6mzeRlRfKHDxlQUBBuI,7204
|
109
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
109
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=0OJZtQuv-StbKZOpalvGi9lcpHJNRPeuclevSaCPggc,29792
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=UkSVW7d1OF_Px9gAjsqqK65aYhIBFWaBO-_IH6_AFfg,4403
|
111
111
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=HfBkSiRTOf9mNF-TNQHTCCs3xSg66F20no0O6vl5v1Y,2154
|
112
112
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=7eG7dl1TcbdnHwgmvyjarxLs0o6Lw_9VAjXAm4rsiBk,3791
|
@@ -157,7 +157,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
|
|
157
157
|
model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
|
158
158
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
159
159
|
model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
|
160
|
-
model_compression_toolkit/core/keras/default_framework_info.py,sha256=
|
160
|
+
model_compression_toolkit/core/keras/default_framework_info.py,sha256=DvK1Tr6z3cQlJw1nx62iFaeSsQSXJl55xOIcJ1uNGu8,5020
|
161
161
|
model_compression_toolkit/core/keras/keras_implementation.py,sha256=_15BrSGTRSSp_8ayuo2x-hdKanew1xuIPSumP46IGSA,32545
|
162
162
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
163
163
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
|
@@ -168,7 +168,7 @@ model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha
|
|
168
168
|
model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
|
169
169
|
model_compression_toolkit/core/keras/back2framework/instance_builder.py,sha256=fBj13c6zkVoWX4JJG18_uXPptiEJqXClE_zFbaFB6Q8,4517
|
170
170
|
model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=TY86-Mb8hmo8RgCcQvkSYIthYOqV9e4VIMpqIyouJ4Y,17397
|
171
|
-
model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=
|
171
|
+
model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=BTDJB6VUAyVapzkwnftdXkv9RaQfwp_GIEk1FyovdGg,14813
|
172
172
|
model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py,sha256=5wFb4nx_F0Wu4c8pLf6n6OzxOHtpOJ6_3mQsNSXIudU,2481
|
173
173
|
model_compression_toolkit/core/keras/graph_substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
174
174
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
@@ -222,7 +222,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
|
|
222
222
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
223
223
|
model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3KqhVz5ORnA46S1uq9Tj5qhtHc,2828
|
224
224
|
model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
|
225
|
-
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256
|
225
|
+
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-byHTXmQEuOiqTAX45BHGi3mRRBF4_EfJ3XhpmVilSU,4355
|
226
226
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
227
227
|
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=c_QFo4e7t6b21CDakGhjVpqy5aXFxxqkdJ-s54HEOfs,31207
|
228
228
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
@@ -232,7 +232,7 @@ model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN
|
|
232
232
|
model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
|
233
233
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
234
234
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
235
|
-
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=
|
235
|
+
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=K4L8FzJFM8_Ge2MHYkSqzCtoZe-ejEhVq8C1RgecyOc,14531
|
236
236
|
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=WccaNiHK12IIimYu29E1oJkQHUdhPCBcIRutefTQ3Ag,19903
|
237
237
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
238
238
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
528
528
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
529
529
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
530
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
534
|
-
mct_nightly-2.3.0.
|
531
|
+
mct_nightly-2.3.0.20250418.531.dist-info/METADATA,sha256=l29V43qlD_uYRJsyqWxSE8HcNNbvTHVxWrr5DEvvVFw,25413
|
532
|
+
mct_nightly-2.3.0.20250418.531.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
533
|
+
mct_nightly-2.3.0.20250418.531.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
534
|
+
mct_nightly-2.3.0.20250418.531.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250418.000531"
|
@@ -22,6 +22,12 @@ from mct_quantizers import QuantizationMethod
|
|
22
22
|
from model_compression_toolkit.defaultdict import DefaultDict
|
23
23
|
|
24
24
|
|
25
|
+
# Default value to use for ops without kernel.
|
26
|
+
# This is a weird default, but it's used all over the place, so for now only extract it to const so that it can be
|
27
|
+
# referenced by variable instead of hard-coded.
|
28
|
+
DEFAULT_KERNEL_ATTRIBUTES = [None]
|
29
|
+
|
30
|
+
|
25
31
|
class ChannelAxis(Enum):
|
26
32
|
"""
|
27
33
|
|
@@ -16,7 +16,7 @@ from collections import namedtuple
|
|
16
16
|
|
17
17
|
from copy import copy, deepcopy
|
18
18
|
from functools import wraps
|
19
|
-
from typing import List, Tuple, Any, Callable
|
19
|
+
from typing import List, Tuple, Any, Callable, Dict
|
20
20
|
|
21
21
|
import networkx as nx
|
22
22
|
import numpy as np
|
@@ -684,7 +684,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
684
684
|
sorted_configurable_nodes.append(n)
|
685
685
|
return sorted_configurable_nodes
|
686
686
|
|
687
|
-
def get_min_candidates_config(self, fw_info: FrameworkInfo) ->
|
687
|
+
def get_min_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
|
688
688
|
"""
|
689
689
|
Builds a minimal configuration.
|
690
690
|
Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate
|
@@ -694,18 +694,13 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
694
694
|
Args:
|
695
695
|
fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
|
696
696
|
|
697
|
-
Returns:
|
697
|
+
Returns:
|
698
|
+
A dict from layer to an index of its minimal candidate.
|
698
699
|
"""
|
699
|
-
|
700
700
|
conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
|
701
|
-
|
702
|
-
|
703
|
-
assert all([len(lst) == 1 for lst in min_cfg_candidates]), \
|
704
|
-
f"A minimal config candidate must be defined, but some node have multiple potential minimal candidates"
|
705
|
-
|
706
|
-
return [lst[0] for lst in min_cfg_candidates]
|
701
|
+
return {n: n.find_min_candidate_index() for n in conf_sorted_nodes}
|
707
702
|
|
708
|
-
def get_max_candidates_config(self, fw_info: FrameworkInfo) ->
|
703
|
+
def get_max_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
|
709
704
|
"""
|
710
705
|
Builds a maximal configuration.
|
711
706
|
Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate
|
@@ -715,16 +710,11 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
715
710
|
Args:
|
716
711
|
fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
|
717
712
|
|
718
|
-
Returns:
|
713
|
+
Returns:
|
714
|
+
A dict from layer to an index of its maximal candidate.
|
719
715
|
"""
|
720
|
-
|
721
716
|
conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
|
722
|
-
|
723
|
-
|
724
|
-
assert all([len(lst) == 1 for lst in max_cfg_candidates]), \
|
725
|
-
f"A maximal config candidate must be defined, but some node have multiple potential maximal candidates"
|
726
|
-
|
727
|
-
return [lst[0] for lst in max_cfg_candidates]
|
717
|
+
return {n: n.find_max_candidate_index() for n in conf_sorted_nodes}
|
728
718
|
|
729
719
|
def get_final_weights_config(self, fw_info: FrameworkInfo) -> List[Tuple[BaseNode, int]]:
|
730
720
|
"""
|
@@ -484,49 +484,35 @@ class BaseNode:
|
|
484
484
|
# for scalar shape (None,) prod returns 1
|
485
485
|
return sum([np.prod([x for x in output_shape if x is not None]) for output_shape in output_shapes])
|
486
486
|
|
487
|
-
def
|
487
|
+
def find_min_candidate_index(self) -> int:
|
488
488
|
"""
|
489
|
-
Returns
|
490
|
-
|
491
|
-
on the Pareto Front, i.e., there is no other candidate that its n_bits pair exceeds in both entries.
|
492
|
-
|
493
|
-
Returns: A list of indices of potential minimal candidates.
|
494
|
-
|
495
|
-
"""
|
496
|
-
|
497
|
-
# We assume that the candidates are sorted according to weights_n_bits first and activation_n_bits second
|
498
|
-
# First, we add the last candidate to the set of minimal candidates (candidate, index)
|
499
|
-
first_min = (len(self.candidates_quantization_cfg) - 1,
|
500
|
-
self.candidates_quantization_cfg[-1].activation_quantization_cfg.activation_n_bits)
|
501
|
-
min_candidates = [first_min]
|
502
|
-
|
503
|
-
# Iterate over all other candidates, and add ones with higher weights_n_bits but smaller activation_n_bits
|
504
|
-
for i, c in reversed(list(enumerate(self.candidates_quantization_cfg))):
|
505
|
-
if c.activation_quantization_cfg.activation_n_bits < first_min[1]:
|
506
|
-
min_candidates.append((i, c))
|
507
|
-
|
508
|
-
return [i for i, a_n_bits in min_candidates]
|
509
|
-
|
510
|
-
def find_max_candidates_indices(self) -> List[int]:
|
489
|
+
Returns:
|
490
|
+
The index of the minimal bit-width candidate.
|
511
491
|
"""
|
512
|
-
|
513
|
-
|
514
|
-
|
492
|
+
aw_nbits = [(c.activation_quantization_cfg.activation_n_bits,
|
493
|
+
*[v.weights_n_bits for v in c.weights_quantization_cfg.get_all_weight_attrs_configs().values()])
|
494
|
+
for c in self.candidates_quantization_cfg]
|
495
|
+
min_nbits = min(aw_nbits)
|
496
|
+
min_ind = [i for i, nb in enumerate(aw_nbits) if min_nbits == nb]
|
497
|
+
# check that no other candidate has a lower nbit for any weight
|
498
|
+
if len(min_ind) > 1 or any(nb[i] < min_nbits[i] for i in range(len(min_nbits)) for nb in aw_nbits):
|
499
|
+
raise ValueError('Expected exactly one candidate with min activation and min weights.')
|
500
|
+
return min_ind[0]
|
515
501
|
|
516
|
-
|
502
|
+
def find_max_candidate_index(self) -> int:
|
517
503
|
"""
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
for i,
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
return [
|
504
|
+
Returns:
|
505
|
+
The index of the maximal bit-width candidate.
|
506
|
+
"""
|
507
|
+
aw_nbits = [(c.activation_quantization_cfg.activation_n_bits,
|
508
|
+
*[v.weights_n_bits for v in c.weights_quantization_cfg.get_all_weight_attrs_configs().values()])
|
509
|
+
for c in self.candidates_quantization_cfg]
|
510
|
+
max_nbits = max(aw_nbits)
|
511
|
+
max_ind = [i for i, nb in enumerate(aw_nbits) if max_nbits == nb]
|
512
|
+
# check that no other candidate has a higher nbit for any weight
|
513
|
+
if len(max_ind) > 1 or any(nb[i] > max_nbits[i] for i in range(len(max_nbits)) for nb in aw_nbits):
|
514
|
+
raise ValueError('Expected exactly one candidate with max activation and max weights.')
|
515
|
+
return max_ind[0]
|
530
516
|
|
531
517
|
def get_unique_weights_candidates(self, attr: str) -> List[Any]:
|
532
518
|
"""
|
@@ -12,12 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from typing import
|
15
|
+
from typing import Set, Dict, Tuple
|
16
16
|
|
17
17
|
import numpy as np
|
18
18
|
|
19
19
|
from model_compression_toolkit.core import FrameworkInfo
|
20
|
-
from model_compression_toolkit.core.common import Graph
|
20
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
21
21
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
23
23
|
RUTarget
|
@@ -36,7 +36,7 @@ class MixedPrecisionRUHelper:
|
|
36
36
|
self.fw_impl = fw_impl
|
37
37
|
self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
|
38
38
|
|
39
|
-
def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg:
|
39
|
+
def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Dict[BaseNode, int]) -> Dict[RUTarget, np.ndarray]:
|
40
40
|
"""
|
41
41
|
Compute utilization of requested targets for a specific configuration:
|
42
42
|
for weights and bops - total utilization,
|
@@ -74,7 +74,7 @@ class MixedPrecisionRUHelper:
|
|
74
74
|
f'Requested {ru_targets}')
|
75
75
|
return ru_dict
|
76
76
|
|
77
|
-
def get_quantization_candidates(self, mp_cfg) \
|
77
|
+
def get_quantization_candidates(self, mp_cfg: Dict[BaseNode, int]) \
|
78
78
|
-> Tuple[Dict[str, NodeActivationQuantizationConfig], Dict[str, NodeWeightsQuantizationConfig]]:
|
79
79
|
"""
|
80
80
|
Retrieve quantization candidates objects for weights and activations from the configuration list.
|
@@ -86,8 +86,7 @@ class MixedPrecisionRUHelper:
|
|
86
86
|
A mapping between nodes to weights quantization config, and a mapping between nodes and activation
|
87
87
|
quantization config.
|
88
88
|
"""
|
89
|
-
|
90
|
-
node_qcs = {n: n.candidates_quantization_cfg[mp_cfg[i]] for i, n in enumerate(mp_nodes)}
|
89
|
+
node_qcs = {n: n.candidates_quantization_cfg[candidate_idx] for n, candidate_idx in mp_cfg.items()}
|
91
90
|
act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
|
92
91
|
w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
|
93
92
|
return act_qcs, w_qcs
|
@@ -14,10 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from enum import Enum
|
17
|
-
from typing import List, Callable
|
17
|
+
from typing import List, Callable, Dict
|
18
18
|
|
19
19
|
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
|
20
|
-
from model_compression_toolkit.core.common import Graph
|
20
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
21
21
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
22
22
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
23
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
@@ -100,11 +100,13 @@ def search_bit_width(graph: Graph,
|
|
100
100
|
fw_impl,
|
101
101
|
se,
|
102
102
|
target_resource_utilization)
|
103
|
-
|
103
|
+
nodes_bit_cfg = search_manager.search()
|
104
104
|
|
105
105
|
graph.skip_validation_check = False
|
106
106
|
|
107
107
|
if mp_config.refine_mp_solution:
|
108
|
-
|
108
|
+
nodes_bit_cfg = greedy_solution_refinement_procedure(nodes_bit_cfg, search_manager, target_resource_utilization)
|
109
109
|
|
110
|
-
|
110
|
+
topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes(fw_info)]
|
111
|
+
assert len(topo_bit_cfg) == len(nodes_bit_cfg)
|
112
|
+
return topo_bit_cfg
|
@@ -12,6 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
import itertools
|
16
|
+
|
15
17
|
import copy
|
16
18
|
from collections import defaultdict
|
17
19
|
|
@@ -21,7 +23,6 @@ from typing import Dict, List, Tuple
|
|
21
23
|
|
22
24
|
import numpy as np
|
23
25
|
|
24
|
-
from model_compression_toolkit.constants import EPS
|
25
26
|
from model_compression_toolkit.core.common import BaseNode
|
26
27
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
27
28
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
@@ -75,34 +76,44 @@ class MixedPrecisionSearchManager:
|
|
75
76
|
self.target_resource_utilization = target_resource_utilization
|
76
77
|
|
77
78
|
self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
|
78
|
-
self.layer_to_bitwidth_mapping = self.get_search_space()
|
79
79
|
|
80
80
|
self.ru_targets = target_resource_utilization.get_restricted_targets()
|
81
81
|
self.ru_helper = MixedPrecisionRUHelper(self.mp_graph, fw_info, fw_impl)
|
82
82
|
|
83
|
-
self.min_ru_config = self.mp_graph.get_min_candidates_config(fw_info)
|
84
|
-
self.max_ru_config = self.mp_graph.get_max_candidates_config(fw_info)
|
83
|
+
self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
|
84
|
+
self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info)
|
85
85
|
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
|
86
86
|
|
87
87
|
self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.mp_graph,
|
88
88
|
original_graph=self.original_graph)
|
89
89
|
|
90
|
-
def search(self) ->
|
90
|
+
def search(self) -> Dict[BaseNode, int]:
|
91
91
|
"""
|
92
92
|
Run mixed precision search.
|
93
93
|
|
94
94
|
Returns:
|
95
|
-
|
95
|
+
Mapping from nodes to indices of the selected bit-widths candidate.
|
96
96
|
"""
|
97
|
-
|
98
|
-
candidates_ru = self._compute_relative_ru_matrices()
|
99
|
-
rel_target_ru = self._get_relative_ru_constraint_per_mem_element()
|
100
|
-
solver = MixedPrecisionIntegerLPSolver(candidates_sensitivity, candidates_ru, rel_target_ru)
|
101
|
-
config = solver.run()
|
97
|
+
mp_config = self._prepare_and_run_solver()
|
102
98
|
|
103
99
|
if self.using_virtual_graph:
|
104
|
-
|
105
|
-
|
100
|
+
mp_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(mp_config)
|
101
|
+
|
102
|
+
return mp_config
|
103
|
+
|
104
|
+
def _prepare_and_run_solver(self) -> Dict[BaseNode, int]:
|
105
|
+
"""
|
106
|
+
Prepare sensitivity and ru data for LP solver and run the solver.
|
107
|
+
|
108
|
+
Returns:
|
109
|
+
Mapping from nodes to indices of the selected bit-widths candidate.
|
110
|
+
"""
|
111
|
+
layers_candidates_sensitivity: Dict[BaseNode, List[float]] = self._build_sensitivity_mapping()
|
112
|
+
candidates_ru = self._compute_relative_ru_matrices()
|
113
|
+
rel_target_ru = self._get_relative_ru_constraint_per_mem_element()
|
114
|
+
solver = MixedPrecisionIntegerLPSolver(layers_candidates_sensitivity, candidates_ru, rel_target_ru)
|
115
|
+
mp_config = solver.run()
|
116
|
+
return mp_config
|
106
117
|
|
107
118
|
def _get_relative_ru_constraint_per_mem_element(self) -> Dict[RUTarget, np.ndarray]:
|
108
119
|
"""
|
@@ -119,7 +130,7 @@ class MixedPrecisionSearchManager:
|
|
119
130
|
"""
|
120
131
|
target_ru = self.target_resource_utilization.get_resource_utilization_dict(restricted_only=True)
|
121
132
|
rel_target_ru = {
|
122
|
-
ru_target: ru - self.min_ru[ru_target] for ru_target, ru in target_ru.items()
|
133
|
+
ru_target: (ru - self.min_ru[ru_target]) for ru_target, ru in target_ru.items()
|
123
134
|
}
|
124
135
|
unsatisfiable_targets = {
|
125
136
|
ru_target.value: target_ru[ru_target] for ru_target, ru in rel_target_ru.items() if any(ru < 0)
|
@@ -129,28 +140,31 @@ class MixedPrecisionSearchManager:
|
|
129
140
|
f"following targets: {unsatisfiable_targets}")
|
130
141
|
return rel_target_ru
|
131
142
|
|
132
|
-
def _build_sensitivity_mapping(self, eps: float =
|
143
|
+
def _build_sensitivity_mapping(self, eps: float = 1e-6) -> Dict[BaseNode, List[float]]:
|
133
144
|
"""
|
134
145
|
This function measures the sensitivity of a change in a bitwidth of a layer on the entire model.
|
135
|
-
It builds a mapping from a node's index, to its bitwidht's effect on the model sensitivity.
|
136
|
-
For each node and some possible node's bitwidth (according to the given search space), we use
|
137
|
-
the framework function compute_metric_fn in order to infer
|
138
|
-
a batch of images, and compute (using the inference results) the sensitivity metric of
|
139
|
-
the configured mixed-precision model.
|
140
146
|
|
141
147
|
Args:
|
142
|
-
eps:
|
148
|
+
eps: if sensitivity for a non-max candidate is lower than for a max candidate, we set it to
|
149
|
+
sensitivity of a max candidate + epsilon.
|
143
150
|
|
144
151
|
Returns:
|
145
|
-
Mapping from
|
146
|
-
the sensitivity of the model.
|
147
|
-
|
152
|
+
Mapping from nodes to their bitwidth candidates sensitivity.
|
148
153
|
"""
|
149
154
|
|
150
155
|
Logger.info('Starting to evaluate metrics')
|
151
|
-
layer_to_metrics_mapping = {}
|
152
156
|
|
153
|
-
|
157
|
+
orig_sorted_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
|
158
|
+
|
159
|
+
def topo_cfg(cfg: dict) -> list:
|
160
|
+
topo_cfg = [cfg[n] for n in orig_sorted_nodes]
|
161
|
+
assert len(topo_cfg) == len(cfg)
|
162
|
+
return topo_cfg
|
163
|
+
|
164
|
+
def compute_metric(cfg, node_idx=None, baseline_cfg=None):
|
165
|
+
return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
|
166
|
+
node_idx,
|
167
|
+
topo_cfg(baseline_cfg) if baseline_cfg else None)
|
154
168
|
if self.using_virtual_graph:
|
155
169
|
origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
|
156
170
|
self.max_ru_config)
|
@@ -158,19 +172,17 @@ class MixedPrecisionSearchManager:
|
|
158
172
|
else:
|
159
173
|
max_config_value = compute_metric(self.max_ru_config)
|
160
174
|
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
for bitwidth_idx in layer_possible_bitwidths_indices:
|
166
|
-
if self.max_ru_config[node_idx] == bitwidth_idx:
|
175
|
+
layer_to_metrics_mapping = defaultdict(list)
|
176
|
+
for node_idx, node in tqdm(enumerate(self.mp_topo_configurable_nodes)):
|
177
|
+
for bitwidth_idx, _ in enumerate(node.candidates_quantization_cfg):
|
178
|
+
if self.max_ru_config[node] == bitwidth_idx:
|
167
179
|
# This is a computation of the metric for the max configuration, assign pre-calculated value
|
168
|
-
layer_to_metrics_mapping[
|
180
|
+
layer_to_metrics_mapping[node].append(max_config_value)
|
169
181
|
continue
|
170
182
|
|
171
183
|
# Create a configuration that differs at one layer only from the baseline model
|
172
184
|
mp_model_configuration = self.max_ru_config.copy()
|
173
|
-
mp_model_configuration[
|
185
|
+
mp_model_configuration[node] = bitwidth_idx
|
174
186
|
|
175
187
|
# Build a distance matrix using the function we got from the framework implementation.
|
176
188
|
if self.using_virtual_graph:
|
@@ -180,8 +192,8 @@ class MixedPrecisionSearchManager:
|
|
180
192
|
mp_model_configuration,
|
181
193
|
changed_virtual_nodes_idx=[node_idx],
|
182
194
|
original_base_config=origin_max_config)
|
183
|
-
origin_changed_nodes_indices = [i for i, c in enumerate(origin_max_config) if
|
184
|
-
c != origin_mp_model_configuration[
|
195
|
+
origin_changed_nodes_indices = [i for i, (n, c) in enumerate(origin_max_config.items()) if
|
196
|
+
c != origin_mp_model_configuration[n]]
|
185
197
|
metric_value = compute_metric(
|
186
198
|
origin_mp_model_configuration,
|
187
199
|
origin_changed_nodes_indices,
|
@@ -191,11 +203,11 @@ class MixedPrecisionSearchManager:
|
|
191
203
|
mp_model_configuration,
|
192
204
|
[node_idx],
|
193
205
|
self.max_ru_config)
|
194
|
-
|
195
|
-
layer_to_metrics_mapping[
|
206
|
+
metric_value = max(metric_value, max_config_value + eps)
|
207
|
+
layer_to_metrics_mapping[node].append(metric_value)
|
196
208
|
|
197
209
|
# Finalize distance metric mapping
|
198
|
-
self.
|
210
|
+
self._finalize_distance_metric(layer_to_metrics_mapping)
|
199
211
|
|
200
212
|
return layer_to_metrics_mapping
|
201
213
|
|
@@ -221,22 +233,6 @@ class MixedPrecisionSearchManager:
|
|
221
233
|
|
222
234
|
return graph, False
|
223
235
|
|
224
|
-
def get_search_space(self) -> Dict[int, List[int]]:
|
225
|
-
"""
|
226
|
-
The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces
|
227
|
-
for the node).
|
228
|
-
|
229
|
-
Returns:
|
230
|
-
The entire search space of the graph.
|
231
|
-
"""
|
232
|
-
|
233
|
-
indices_mapping = {}
|
234
|
-
for idx, n in enumerate(self.mp_topo_configurable_nodes):
|
235
|
-
# For each node, get all possible bitwidth indices for it
|
236
|
-
# (which is a list from 0 to the length of the candidates mp_config list of the node).
|
237
|
-
indices_mapping[idx] = list(range(len(n.candidates_quantization_cfg))) # all search_methods space
|
238
|
-
return indices_mapping
|
239
|
-
|
240
236
|
def _compute_relative_ru_matrices(self) -> Dict[RUTarget, np.ndarray]:
|
241
237
|
"""
|
242
238
|
Computes and builds a resource utilization matrix for all restricted targets, to be used for the
|
@@ -248,55 +244,41 @@ class MixedPrecisionSearchManager:
|
|
248
244
|
per ru target. Num memory elements depends on the target, e.g. num cuts or 1 for cumulative metrics.
|
249
245
|
"""
|
250
246
|
rus_per_candidate = defaultdict(list)
|
251
|
-
for
|
252
|
-
for candidate_idx in
|
253
|
-
if candidate_idx == self.min_ru_config[
|
247
|
+
for node in self.mp_topo_configurable_nodes:
|
248
|
+
for candidate_idx, _ in enumerate(node.candidates_quantization_cfg):
|
249
|
+
if candidate_idx == self.min_ru_config[node]:
|
254
250
|
candidate_rus = self.min_ru
|
255
251
|
else:
|
256
|
-
|
252
|
+
cfg = self.min_ru_config.copy()
|
253
|
+
cfg[node] = candidate_idx
|
254
|
+
candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, cfg)
|
257
255
|
|
258
256
|
for target, ru in candidate_rus.items():
|
259
257
|
rus_per_candidate[target].append(ru)
|
260
258
|
|
261
259
|
# Each target contains a matrix of num configurations X num elements
|
262
|
-
relative_rus = {target: np.array(ru) - self.min_ru[target] for target, ru in rus_per_candidate.items()}
|
260
|
+
relative_rus = {target: (np.array(ru) - self.min_ru[target]) for target, ru in rus_per_candidate.items()}
|
263
261
|
return relative_rus
|
264
262
|
|
265
|
-
def compute_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int) -> Dict[RUTarget, np.ndarray]:
|
266
|
-
"""
|
267
|
-
Computes a resource utilization vector after replacing the given node's configuration candidate in the minimal
|
268
|
-
target configuration with the given candidate index.
|
269
|
-
|
270
|
-
Args:
|
271
|
-
conf_node_idx: The index of a node in a sorted configurable nodes list.
|
272
|
-
candidate_idx: Quantization config candidate to be used for the node's resource utilization computation.
|
273
|
-
|
274
|
-
Returns:
|
275
|
-
Node's resource utilization vector.
|
276
|
-
|
277
|
-
"""
|
278
|
-
cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
|
279
|
-
return self.ru_helper.compute_utilization(self.ru_targets, cfg)
|
280
|
-
|
281
263
|
@staticmethod
|
282
|
-
def
|
264
|
+
def copy_config_with_replacement(mp_cfg: Dict[BaseNode, int], node: BaseNode, candidate_idx: int) -> Dict[BaseNode, int]:
|
283
265
|
"""
|
284
|
-
|
285
|
-
index (node's index) with the given value (candidate index).
|
266
|
+
Create a copy of the given mixed-precision configuration and update the candidate index for a specific node.
|
286
267
|
|
287
268
|
Args:
|
288
|
-
mp_cfg: Mixed-precision configuration
|
289
|
-
|
290
|
-
|
269
|
+
mp_cfg: Mixed-precision configuration.
|
270
|
+
node: Node to update the config for.
|
271
|
+
candidate_idx: A new candidate index to configure.
|
291
272
|
|
292
|
-
Returns:
|
273
|
+
Returns:
|
274
|
+
A new mixed-precision configuration.
|
293
275
|
|
294
276
|
"""
|
295
277
|
updated_cfg = mp_cfg.copy()
|
296
|
-
updated_cfg[
|
278
|
+
updated_cfg[node] = candidate_idx
|
297
279
|
return updated_cfg
|
298
280
|
|
299
|
-
def compute_resource_utilization_for_config(self, config:
|
281
|
+
def compute_resource_utilization_for_config(self, config: Dict[BaseNode, int]) -> ResourceUtilization:
|
300
282
|
"""
|
301
283
|
Computes the resource utilization values for a given mixed-precision configuration.
|
302
284
|
|
@@ -313,7 +295,7 @@ class MixedPrecisionSearchManager:
|
|
313
295
|
w_qcs=w_qcs, ru_targets=self.ru_targets, allow_unused_qcs=True)
|
314
296
|
return ru
|
315
297
|
|
316
|
-
def
|
298
|
+
def _finalize_distance_metric(self, layer_to_metrics_mapping: Dict[BaseNode, List[float]]):
|
317
299
|
"""
|
318
300
|
Finalizing the distance metric building.
|
319
301
|
The method checks to see if the maximal distance value is larger than a given threshold, and if so,
|
@@ -321,21 +303,20 @@ class MixedPrecisionSearchManager:
|
|
321
303
|
Modification to the dictionary is done inplace.
|
322
304
|
|
323
305
|
Args:
|
324
|
-
layer_to_metrics_mapping: A mapping between a node
|
325
|
-
a bitwidth index to a distance value.
|
306
|
+
layer_to_metrics_mapping: A mapping between a node to a list of distance values per bitwidth candidate.
|
326
307
|
|
327
308
|
"""
|
328
309
|
# normalize metric for numerical stability
|
310
|
+
max_dist = max(itertools.chain.from_iterable(layer_to_metrics_mapping.values()))
|
329
311
|
|
330
|
-
max_dist = max([max([d for b, d in dists.items()]) for layer, dists in layer_to_metrics_mapping.items()])
|
331
312
|
if max_dist >= self.sensitivity_evaluator.quant_config.metric_normalization_threshold:
|
332
313
|
Logger.warning(f"The mixed precision distance metric values indicate a large error in the quantized model."
|
333
314
|
f"this can cause numerical issues."
|
334
315
|
f"The program will proceed with mixed precision search after scaling the metric values,"
|
335
316
|
f"which can lead to unstable results.")
|
336
317
|
for layer, dists in layer_to_metrics_mapping.items():
|
337
|
-
for
|
338
|
-
layer_to_metrics_mapping[layer][
|
318
|
+
for i, _ in enumerate(dists):
|
319
|
+
layer_to_metrics_mapping[layer][i] /= max_dist
|
339
320
|
|
340
321
|
|
341
322
|
class ConfigReconstructionHelper:
|
@@ -363,7 +344,8 @@ class ConfigReconstructionHelper:
|
|
363
344
|
self.fw_info = original_graph.fw_info
|
364
345
|
|
365
346
|
self.virtual_sorted_nodes_names = self.virtual_graph.get_configurable_sorted_nodes_names(self.fw_info)
|
366
|
-
self.
|
347
|
+
self.origin_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
|
348
|
+
self.origin_sorted_conf_nodes_names = [n.name for n in self.origin_sorted_conf_nodes]
|
367
349
|
|
368
350
|
self.origin_node_idx_to_cfg = {}
|
369
351
|
|
@@ -375,9 +357,9 @@ class ConfigReconstructionHelper:
|
|
375
357
|
self.origin_node_idx_to_cfg = {}
|
376
358
|
|
377
359
|
def reconstruct_config_from_virtual_graph(self,
|
378
|
-
virtual_mp_cfg:
|
360
|
+
virtual_mp_cfg: Dict[BaseNode, int],
|
379
361
|
changed_virtual_nodes_idx: List[int] = None,
|
380
|
-
original_base_config:
|
362
|
+
original_base_config: Dict[BaseNode, int] = None) -> Dict[BaseNode, int]:
|
381
363
|
"""
|
382
364
|
Reconstructs the original config for a given virtual graph mixed-precision config.
|
383
365
|
It iterates over all virtual configurable node (that has some chosen bit-width virtual candidate)
|
@@ -405,21 +387,21 @@ class ConfigReconstructionHelper:
|
|
405
387
|
[(idx, self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)[idx]) for idx in changed_virtual_nodes_idx]
|
406
388
|
# Iterating only over the virtual nodes that have updated config
|
407
389
|
for virtual_node_idx, n in updated_virtual_nodes:
|
408
|
-
self.reconstruct_node_config(n, virtual_mp_cfg, virtual_node_idx)
|
390
|
+
self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
|
409
391
|
# Updating reconstructed config for all other nodes based on provided base_config
|
410
392
|
original_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
|
411
|
-
for i in
|
393
|
+
for i, (n, qc_ind) in enumerate(original_base_config.items()):
|
412
394
|
if i not in list(self.origin_node_idx_to_cfg.keys()):
|
413
|
-
self.update_config_at_original_idx(n=
|
414
|
-
origin_cfg_idx=original_base_config[i])
|
395
|
+
self.update_config_at_original_idx(n=n, origin_cfg_idx=qc_ind)
|
415
396
|
else:
|
416
397
|
# Reconstruct entire config
|
417
398
|
for virtual_node_idx, n in enumerate(self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)):
|
418
|
-
self.reconstruct_node_config(n, virtual_mp_cfg, virtual_node_idx)
|
399
|
+
self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
|
419
400
|
|
420
401
|
res_config = [self.origin_node_idx_to_cfg[key] for key in sorted(self.origin_node_idx_to_cfg.keys())]
|
421
402
|
self._clear_reconstruction_dict()
|
422
|
-
|
403
|
+
assert len(res_config) == len(self.origin_sorted_conf_nodes)
|
404
|
+
return {n: candidate_idx for n, candidate_idx in zip(self.origin_sorted_conf_nodes, res_config)}
|
423
405
|
|
424
406
|
def reconstruct_node_config(self,
|
425
407
|
n: BaseNode,
|
@@ -12,9 +12,11 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from collections import defaultdict
|
16
|
+
|
15
17
|
import numpy as np
|
16
18
|
from pulp import *
|
17
|
-
from typing import Dict, Tuple,
|
19
|
+
from typing import Dict, Tuple, Any
|
18
20
|
|
19
21
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
|
20
22
|
|
@@ -30,23 +32,23 @@ class MixedPrecisionIntegerLPSolver:
|
|
30
32
|
candidates_ru: resource utilization per candidate.
|
31
33
|
ru_constraints: resource utilization constraints corresponding to 'candidates_ru'.
|
32
34
|
"""
|
33
|
-
def __init__(self,
|
35
|
+
def __init__(self,
|
36
|
+
layer_to_sensitivity_mapping: Dict[Any, List[float]],
|
34
37
|
candidates_ru: Dict[RUTarget, np.ndarray],
|
35
38
|
ru_constraints: Dict[RUTarget, np.ndarray]):
|
39
|
+
|
36
40
|
self.layer_to_sensitivity_mapping = layer_to_sensitivity_mapping
|
37
41
|
self.candidates_ru = candidates_ru
|
38
42
|
self.ru_constraints = ru_constraints
|
39
43
|
|
40
|
-
self.
|
41
|
-
self._init_problem_vars(layer_to_sensitivity_mapping))
|
44
|
+
self.layer_to_indicator_vars, self.objective_vars = self._init_problem_vars(layer_to_sensitivity_mapping)
|
42
45
|
|
43
|
-
def run(self) ->
|
46
|
+
def run(self) -> Dict[Any, int]:
|
44
47
|
"""
|
45
48
|
Build and solve an ILP optimization problem.
|
46
49
|
|
47
50
|
Returns:
|
48
|
-
|
49
|
-
|
51
|
+
A dictionary from layer to the index of the selected bitwidth candidate.
|
50
52
|
"""
|
51
53
|
# Add all equations and inequalities that define the problem.
|
52
54
|
lp_problem = self._formalize_problem()
|
@@ -59,17 +61,14 @@ class MixedPrecisionIntegerLPSolver:
|
|
59
61
|
raise RuntimeError(f'No solution was found for the LP problem, with status {lp_problem.status}')
|
60
62
|
|
61
63
|
# Take the bitwidth index only if its corresponding indicator is one.
|
62
|
-
|
63
|
-
[
|
64
|
-
|
65
|
-
|
66
|
-
).flatten()
|
67
|
-
|
68
|
-
return config.tolist()
|
64
|
+
mp_config = {
|
65
|
+
layer: [v.varValue for v in vars].index(1.) for layer, vars in self.layer_to_indicator_vars.items()
|
66
|
+
}
|
67
|
+
return mp_config
|
69
68
|
|
70
69
|
@staticmethod
|
71
|
-
def _init_problem_vars(layer_to_metrics_mapping: Dict[
|
72
|
-
|
70
|
+
def _init_problem_vars(layer_to_metrics_mapping: Dict[Any, List[float]]) -> Tuple[Dict[Any, List[LpVariable]],
|
71
|
+
List[LpVariable]]:
|
73
72
|
"""
|
74
73
|
Initialize the LP problem variables: Variable for each layer as to the index of the bitwidth it should use,
|
75
74
|
and a variable for each indicator for whether we use the former variable or not.
|
@@ -83,21 +82,18 @@ class MixedPrecisionIntegerLPSolver:
|
|
83
82
|
and the second for indicators for each variable.
|
84
83
|
"""
|
85
84
|
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
for layer, nbits_to_metric in layer_to_metrics_mapping.items():
|
90
|
-
layer_to_indicator_vars_mapping[layer] = dict()
|
85
|
+
layer_to_indicator_vars = defaultdict(list)
|
86
|
+
objective_vars = []
|
91
87
|
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
88
|
+
for layer_idx, (layer, bitwidth_metrics) in enumerate(layer_to_metrics_mapping.items()):
|
89
|
+
layer_to_indicator_vars[layer] = [
|
90
|
+
LpVariable(f"layer_{layer_idx}_{qc_idx}", lowBound=0, upBound=1, cat=LpInteger)
|
91
|
+
for qc_idx, _ in enumerate(bitwidth_metrics)
|
92
|
+
]
|
97
93
|
|
98
|
-
|
94
|
+
objective_vars.append(LpVariable(f"s_{layer_idx}", 0))
|
99
95
|
|
100
|
-
return
|
96
|
+
return layer_to_indicator_vars, objective_vars
|
101
97
|
|
102
98
|
def _formalize_problem(self) -> LpProblem:
|
103
99
|
"""
|
@@ -108,18 +104,16 @@ class MixedPrecisionIntegerLPSolver:
|
|
108
104
|
"""
|
109
105
|
|
110
106
|
lp_problem = LpProblem() # minimization problem by default
|
111
|
-
lp_problem += lpSum(
|
112
|
-
self.layer_to_sensitivity_mapping.keys()]) # Objective (minimize acc loss)
|
107
|
+
lp_problem += lpSum(self.objective_vars)
|
113
108
|
|
114
|
-
for
|
109
|
+
for layer_sensitivity, layer_indicator_vars, obj_var in zip(self.layer_to_sensitivity_mapping.values(),
|
110
|
+
self.layer_to_indicator_vars.values(),
|
111
|
+
self.objective_vars):
|
115
112
|
# Use every bitwidth for every layer with its indicator.
|
116
|
-
lp_problem += lpSum(
|
117
|
-
for nbits, indicator in self.layer_to_indicator_vars_mapping[layer].items()]) == \
|
118
|
-
self.layer_to_objective_vars_mapping[layer]
|
113
|
+
lp_problem += lpSum(list(np.multiply(layer_indicator_vars, layer_sensitivity))) == obj_var
|
119
114
|
|
120
115
|
# Constraint of only one indicator==1
|
121
|
-
lp_problem += lpSum(
|
122
|
-
[v for v in self.layer_to_indicator_vars_mapping[layer].values()]) == 1
|
116
|
+
lp_problem += lpSum(layer_indicator_vars) == 1
|
123
117
|
|
124
118
|
# Bound the feasible solution space with the desired resource utilization values.
|
125
119
|
self._add_ru_constraints(lp_problem=lp_problem)
|
@@ -134,10 +128,7 @@ class MixedPrecisionIntegerLPSolver:
|
|
134
128
|
Args:
|
135
129
|
lp_problem: An Lp problem object to add constraint to.
|
136
130
|
"""
|
137
|
-
|
138
|
-
for layer in self.layer_to_sensitivity_mapping:
|
139
|
-
indicators.extend(list(self.layer_to_indicator_vars_mapping[layer].values()))
|
140
|
-
indicators_vec = np.array(indicators)
|
131
|
+
indicator_vars = list(itertools.chain(*self.layer_to_indicator_vars.values()))
|
141
132
|
|
142
133
|
for target, ru_matrix in self.candidates_ru.items():
|
143
134
|
# We expect 2d matrix of shape (num candidates, m). For cumulative metrics (weights, bops) m=1 - overall
|
@@ -146,7 +137,7 @@ class MixedPrecisionIntegerLPSolver:
|
|
146
137
|
if target in [RUTarget.WEIGHTS, RUTarget.BOPS]:
|
147
138
|
assert ru_matrix.shape[1] == 1
|
148
139
|
|
149
|
-
indicated_ru_matrix = ru_matrix.T *
|
140
|
+
indicated_ru_matrix = ru_matrix.T * np.array(indicator_vars)
|
150
141
|
# build lp sum term over all candidates
|
151
142
|
ru_vec = indicated_ru_matrix.sum(axis=1)
|
152
143
|
|
@@ -16,6 +16,7 @@
|
|
16
16
|
from typing import List, Tuple, Dict
|
17
17
|
|
18
18
|
from model_compression_toolkit.core import ResourceUtilization
|
19
|
+
from model_compression_toolkit.core.common import BaseNode
|
19
20
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
|
20
21
|
MixedPrecisionSearchManager
|
21
22
|
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
|
@@ -23,9 +24,9 @@ from model_compression_toolkit.core.common.quantization.candidate_node_quantizat
|
|
23
24
|
from model_compression_toolkit.logger import Logger
|
24
25
|
|
25
26
|
|
26
|
-
def greedy_solution_refinement_procedure(mp_solution:
|
27
|
+
def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
|
27
28
|
search_manager: MixedPrecisionSearchManager,
|
28
|
-
target_resource_utilization: ResourceUtilization) ->
|
29
|
+
target_resource_utilization: ResourceUtilization) -> Dict[BaseNode, int]:
|
29
30
|
"""
|
30
31
|
A greedy procedure to try and improve a mixed-precision solution that was found by a mixed-precision optimization
|
31
32
|
algorithm.
|
@@ -50,6 +51,8 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
50
51
|
Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
|
51
52
|
return mp_solution
|
52
53
|
|
54
|
+
assert search_manager.using_virtual_graph is False
|
55
|
+
|
53
56
|
new_solution = mp_solution.copy()
|
54
57
|
changed = True
|
55
58
|
|
@@ -58,17 +61,16 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
58
61
|
nodes_ru = {}
|
59
62
|
nodes_next_candidate = {}
|
60
63
|
|
61
|
-
for
|
62
|
-
if new_solution[
|
64
|
+
for node in search_manager.mp_topo_configurable_nodes:
|
65
|
+
if new_solution[node] == 0:
|
63
66
|
# layer has max config in the given solution, nothing to optimize
|
64
67
|
continue
|
65
68
|
|
66
|
-
|
67
|
-
node_candidates = current_node.candidates_quantization_cfg
|
69
|
+
node_candidates = node.candidates_quantization_cfg
|
68
70
|
|
69
71
|
# only weights kernel attribute is quantized with weights mixed precision
|
70
72
|
valid_candidates = _get_valid_candidates_indices(node_candidates,
|
71
|
-
new_solution[
|
73
|
+
new_solution[node],
|
72
74
|
target_resource_utilization.activation_restricted(),
|
73
75
|
target_resource_utilization.weight_restricted()
|
74
76
|
)
|
@@ -77,7 +79,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
77
79
|
updated_ru = []
|
78
80
|
for valid_idx in valid_candidates:
|
79
81
|
node_updated_ru = search_manager.compute_resource_utilization_for_config(
|
80
|
-
config=search_manager.
|
82
|
+
config=search_manager.copy_config_with_replacement(new_solution, node, valid_idx))
|
81
83
|
updated_ru.append(node_updated_ru)
|
82
84
|
|
83
85
|
# filter out new configs that don't hold the resource utilization restrictions
|
@@ -88,8 +90,8 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
88
90
|
sorted_by_ru = sorted(node_filtered_ru, key=lambda node_ru: (node_ru[1].total_memory,
|
89
91
|
node_ru[1].weights_memory,
|
90
92
|
node_ru[1].activation_memory))
|
91
|
-
nodes_ru[
|
92
|
-
nodes_next_candidate[
|
93
|
+
nodes_ru[node] = sorted_by_ru[0][1]
|
94
|
+
nodes_next_candidate[node] = sorted_by_ru[0][0]
|
93
95
|
|
94
96
|
if len(nodes_ru) > 0:
|
95
97
|
# filter out new configs that don't hold the ru restrictions
|
@@ -102,7 +104,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
102
104
|
new_solution[node_idx_to_upgrade] = nodes_next_candidate[node_idx_to_upgrade]
|
103
105
|
changed = True
|
104
106
|
|
105
|
-
if any([mp_solution[
|
107
|
+
if any([mp_solution[n] != new_solution[n] for n in mp_solution]):
|
106
108
|
Logger.info(f'Greedy MP algorithm changed configuration from (numbers represent indices of the '
|
107
109
|
f'chosen bit-width candidate for each layer):\n{mp_solution}\nto\n{new_solution}')
|
108
110
|
|
@@ -464,7 +464,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
464
464
|
weights_attr_cfg=attr_cfg,
|
465
465
|
weights_channels_axis=weights_channels_axis)
|
466
466
|
|
467
|
-
def get_attr_config(self, attr_name:
|
467
|
+
def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
|
468
468
|
"""
|
469
469
|
Returns a weights attribute config for an attribute that contains the given name.
|
470
470
|
If multiple attributes that contain the given name are found - looking for the exact name, otherwise,
|
@@ -499,7 +499,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
499
499
|
|
500
500
|
return attr_cfg
|
501
501
|
|
502
|
-
def set_attr_config(self, attr_name:
|
502
|
+
def set_attr_config(self, attr_name: 'WeightAttrT', attr_qc: WeightsAttrQuantizationConfig):
|
503
503
|
"""
|
504
504
|
Adding a new attribute with quantization configuration to the node's weights configurations mapping.
|
505
505
|
|
@@ -513,7 +513,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
513
513
|
else:
|
514
514
|
self.attributes_config_mapping[attr_name] = attr_qc
|
515
515
|
|
516
|
-
def has_attribute_config(self, attr_name:
|
516
|
+
def has_attribute_config(self, attr_name: 'WeightAttrT') -> bool:
|
517
517
|
"""
|
518
518
|
Checks whether the node weights configuration contains a configuration for a given weights attribute.
|
519
519
|
|
@@ -541,6 +541,14 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
541
541
|
"""
|
542
542
|
return list(self.pos_attributes_config_mapping.keys()) + list(self.attributes_config_mapping.keys())
|
543
543
|
|
544
|
+
def get_all_weight_attrs_configs(self) -> Dict['WeightAttrT', AttributeQuantizationConfig]:
|
545
|
+
""" Get quantization configs for all weights.
|
546
|
+
|
547
|
+
Returns:
|
548
|
+
A dict from weight attribute to its config.
|
549
|
+
"""
|
550
|
+
return {attr: self.get_attr_config(attr) for attr in self.all_weight_attrs}
|
551
|
+
|
544
552
|
def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, WeightsAttrQuantizationConfig]:
|
545
553
|
"""
|
546
554
|
Extract the saved attributes that contain the given attribute name.
|
@@ -560,7 +568,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
560
568
|
return attrs_with_name
|
561
569
|
|
562
570
|
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
|
563
|
-
attr_name:
|
571
|
+
attr_name: 'WeightAttrT' = None, *args: List[Any], **kwargs: Dict[str, Any]):
|
564
572
|
"""
|
565
573
|
This method overrides the parent class set_quant_config_attr to enable setting a specific weights
|
566
574
|
attribute config parameter.
|
@@ -137,11 +137,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
|
|
137
137
|
|
138
138
|
float_weights = n.get_weights_by_keys(attr)
|
139
139
|
|
140
|
-
|
141
|
-
if not len(max_cfg_candidates) == 1:
|
142
|
-
Logger.critical(f"A maximal configuration candidate must be defined; found multiple potential maximal candidates.")# pragma: no cover
|
143
|
-
|
144
|
-
max_candidate_idx = max_cfg_candidates[0]
|
140
|
+
max_candidate_idx = n.find_max_candidate_index()
|
145
141
|
|
146
142
|
return {'node_q_cfg': node_q_cfg_candidates,
|
147
143
|
'float_weights': float_weights,
|
@@ -178,11 +174,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
|
|
178
174
|
# if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
|
179
175
|
n.sort_node_candidates(self.fw_info)
|
180
176
|
|
181
|
-
|
182
|
-
assert len(max_cfg_candidates) == 1, \
|
183
|
-
f"A maximal config candidate must be defined, but some node have multiple potential maximal candidates"
|
184
|
-
max_candidate_idx = max_cfg_candidates[0]
|
185
|
-
|
177
|
+
max_candidate_idx = n.find_max_candidate_index()
|
186
178
|
kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
|
187
179
|
activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
|
188
180
|
'max_candidate_idx': max_candidate_idx,
|
@@ -25,7 +25,7 @@ else:
|
|
25
25
|
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU # pragma: no cover
|
26
26
|
|
27
27
|
from model_compression_toolkit.defaultdict import DefaultDict
|
28
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
28
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
|
29
29
|
from mct_quantizers import QuantizationMethod
|
30
30
|
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
31
31
|
from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
|
@@ -39,7 +39,7 @@ If a layer that is not listed here is queried, [None] is returned.
|
|
39
39
|
KERNEL_ATTRIBUTES = DefaultDict({Conv2D: [KERNEL],
|
40
40
|
DepthwiseConv2D: [DEPTHWISE_KERNEL],
|
41
41
|
Dense: [KERNEL],
|
42
|
-
Conv2DTranspose: [KERNEL]},
|
42
|
+
Conv2DTranspose: [KERNEL]}, DEFAULT_KERNEL_ATTRIBUTES)
|
43
43
|
|
44
44
|
|
45
45
|
"""
|
@@ -136,11 +136,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
136
136
|
|
137
137
|
float_weights = n.get_weights_by_keys(attr)
|
138
138
|
|
139
|
-
|
140
|
-
if not len(max_cfg_candidates) == 1:
|
141
|
-
Logger.critical(f"A maximal configuration candidate must be uniquely defined; however, multiple potential maximal candidates were found.") # pragma: no cover
|
142
|
-
|
143
|
-
max_candidate_idx = max_cfg_candidates[0]
|
139
|
+
max_candidate_idx = n.find_max_candidate_index()
|
144
140
|
|
145
141
|
return {'node_q_cfg': node_q_cfg_candidates,
|
146
142
|
'float_weights': float_weights,
|
@@ -175,10 +171,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
175
171
|
# if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
|
176
172
|
n.sort_node_candidates(self.fw_info)
|
177
173
|
|
178
|
-
|
179
|
-
assert len(max_cfg_candidates) == 1, \
|
180
|
-
f"A maximal configuration candidate must be uniquely defined; however, multiple potential maximal candidates were found."
|
181
|
-
max_candidate_idx = max_cfg_candidates[0]
|
174
|
+
max_candidate_idx = n.find_max_candidate_index()
|
182
175
|
|
183
176
|
kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
|
184
177
|
activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
|
@@ -18,7 +18,7 @@ from torch.nn import Conv2d, ConvTranspose2d, Linear
|
|
18
18
|
from torch import sigmoid
|
19
19
|
|
20
20
|
from model_compression_toolkit.defaultdict import DefaultDict
|
21
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
21
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
|
22
22
|
from mct_quantizers import QuantizationMethod
|
23
23
|
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
24
24
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
@@ -33,7 +33,7 @@ If a layer that is not listed here is queried, [None] is returned.
|
|
33
33
|
KERNEL_ATTRIBUTES = DefaultDict({Conv2d: [KERNEL],
|
34
34
|
ConvTranspose2d: [KERNEL],
|
35
35
|
Linear: [KERNEL]},
|
36
|
-
|
36
|
+
DEFAULT_KERNEL_ATTRIBUTES)
|
37
37
|
|
38
38
|
"""
|
39
39
|
Map a layer to its kernel's output and input channels indices.
|
File without changes
|
File without changes
|
{mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250418.531.dist-info}/top_level.txt
RENAMED
File without changes
|