mct-nightly 2.3.0.20250416.541__py3-none-any.whl → 2.3.0.20250417.547__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250417.547.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250417.547.dist-info}/RECORD +19 -19
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_info.py +6 -0
  5. model_compression_toolkit/core/common/graph/base_graph.py +9 -19
  6. model_compression_toolkit/core/common/graph/base_node.py +25 -39
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +5 -6
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +7 -5
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +82 -100
  10. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +32 -41
  11. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +13 -11
  12. model_compression_toolkit/core/common/quantization/node_quantization_config.py +12 -4
  13. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +2 -10
  14. model_compression_toolkit/core/keras/default_framework_info.py +2 -2
  15. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +2 -9
  16. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  17. {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250417.547.dist-info}/WHEEL +0 -0
  18. {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250417.547.dist-info}/licenses/LICENSE.md +0 -0
  19. {mct_nightly-2.3.0.20250416.541.dist-info → mct_nightly-2.3.0.20250417.547.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250416.541
3
+ Version: 2.3.0.20250417.547
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -1,5 +1,5 @@
1
- mct_nightly-2.3.0.20250416.541.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=AeZ2o5FMPLxX0sepHjLsV8WP2kgUvZWHt78DlPDh7u8,1557
1
+ mct_nightly-2.3.0.20250417.547.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=_Cp9pP6V9N9TqIEjqugu9Spw91FwzVTnoFww5V-hqIs,1557
3
3
  model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -13,7 +13,7 @@ model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utyb
13
13
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
14
14
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
15
15
  model_compression_toolkit/core/common/framework_implementation.py,sha256=L88uv_sfYM_56FSmxXP--emjv01_lk7IPqOI7QBZEt0,22939
16
- model_compression_toolkit/core/common/framework_info.py,sha256=RWeZfQOPiBroU2v4AeZoquVunNtZ4UORjOr2aRAPu8o,6279
16
+ model_compression_toolkit/core/common/framework_info.py,sha256=5tderHT-7Cd21QrRFIJj3hH_gAcnlivOzwZ5m1ldJOs,6526
17
17
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
18
18
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
19
19
  model_compression_toolkit/core/common/model_collector.py,sha256=Tno3-qx9jmPZAZyLYgbPlMLHakVfuEH5deuToZNuCb0,13195
@@ -34,8 +34,8 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
34
34
  model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=W8qZejLwbm-lkvNF3GepNL3ypO10vFRxOxbq-o_rt_I,15479
35
35
  model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
36
36
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
37
- model_compression_toolkit/core/common/graph/base_graph.py,sha256=3OhaMHW01okwFY4mSy0ERFCJk8AZPDs8bCKAmjvmJEI,41893
38
- model_compression_toolkit/core/common/graph/base_node.py,sha256=Yl6GdjnP_Rt9w1lQUm00CJI0JUAffQF7wr6mur_YfbA,34124
37
+ model_compression_toolkit/core/common/graph/base_graph.py,sha256=2aRpL8OP-JWKc2XFdsAQjACthJZmS8zgwIX-wjBRCFQ,41383
38
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=AbUadAT581zelVcGcK9_--6CAGiht9qwkeWahwT3RzE,33389
39
39
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
40
40
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
41
41
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
@@ -67,18 +67,18 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_uti
67
67
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
68
68
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
69
69
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
70
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=2Pp4hiYvGW2I9YhloDxQNT0sZRg3TDp9CXObloF8IFU,4971
71
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=GGrp7QngrWvWtPN8cQnL4IEbNwcVRc-hAUqfnxjjMmk,5998
72
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=NBzzhkVI407S9cIiw7t7nsP3MrkOdSnweKQdPBXb8to,38180
70
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=-hOMBucYn12ePyLd0b1KxniPOIRu4b53SwEzv0bWToI,4943
71
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=d5-3j2e_rdcQOT7c4s0p7640i3nSetjJ6MgMhhMM7dc,6152
72
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=a0lyySRmQ1vKikx5YvDMA4l1Eha-W5BCPYScvDlL_6c,37300
73
73
  model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4bkM8pYKvk18cxHbx973Dz6qWrNT0MRm44cuk__qVaI,27297
74
74
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
75
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=fk7PWiZ6Na5O_Z_dymk_UfDCTqW_X_4EROU7DZknQnc,9444
75
+ model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
76
76
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
77
77
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
78
78
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=cjFnpDvxZDE4K2sgt26DhosA2XqhxHDs0eW5Qe7AwAQ,40668
79
79
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
80
80
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
81
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=TaK5NqVdmygsHw9_x5JsJ-BPvlbKA9cRyTno1R8gbnU,7269
81
+ model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=32s620FyREMBJYx3AUp6umlRfHxjqhL31PRbVtLdMJ4,6664
82
82
  model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
83
83
  model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
84
84
  model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
@@ -106,7 +106,7 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
106
106
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
107
107
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
108
108
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=n2A8pO7_DMMae4o69U0I00iW6mzeRlRfKHDxlQUBBuI,7204
109
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=gL3XUm85FBLvtF60jmWkPxITOBw7cs66scNtC7QHW-M,29471
109
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=0OJZtQuv-StbKZOpalvGi9lcpHJNRPeuclevSaCPggc,29792
110
110
  model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=UkSVW7d1OF_Px9gAjsqqK65aYhIBFWaBO-_IH6_AFfg,4403
111
111
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=HfBkSiRTOf9mNF-TNQHTCCs3xSg66F20no0O6vl5v1Y,2154
112
112
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=7eG7dl1TcbdnHwgmvyjarxLs0o6Lw_9VAjXAm4rsiBk,3791
@@ -157,7 +157,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
157
157
  model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
158
158
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
159
159
  model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
160
- model_compression_toolkit/core/keras/default_framework_info.py,sha256=IGEHKH3IcmpRfyHuEBJTpEXu2-TDFfqQzpm8kHuj8QY,4974
160
+ model_compression_toolkit/core/keras/default_framework_info.py,sha256=DvK1Tr6z3cQlJw1nx62iFaeSsQSXJl55xOIcJ1uNGu8,5020
161
161
  model_compression_toolkit/core/keras/keras_implementation.py,sha256=_15BrSGTRSSp_8ayuo2x-hdKanew1xuIPSumP46IGSA,32545
162
162
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
163
163
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
@@ -168,7 +168,7 @@ model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha
168
168
  model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
169
169
  model_compression_toolkit/core/keras/back2framework/instance_builder.py,sha256=fBj13c6zkVoWX4JJG18_uXPptiEJqXClE_zFbaFB6Q8,4517
170
170
  model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=TY86-Mb8hmo8RgCcQvkSYIthYOqV9e4VIMpqIyouJ4Y,17397
171
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=emsaCYyZBF7oQfXAR0edU7idiMInXLXRuGPcrUp4slM,15301
171
+ model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=BTDJB6VUAyVapzkwnftdXkv9RaQfwp_GIEk1FyovdGg,14813
172
172
  model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py,sha256=5wFb4nx_F0Wu4c8pLf6n6OzxOHtpOJ6_3mQsNSXIudU,2481
173
173
  model_compression_toolkit/core/keras/graph_substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
174
174
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
@@ -222,7 +222,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
222
222
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
223
223
  model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3KqhVz5ORnA46S1uq9Tj5qhtHc,2828
224
224
  model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
225
- model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=NLdmiig5a2EBxutJeDHjp8px4g_2EKt3zmntmK-NrT4,4309
225
+ model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-byHTXmQEuOiqTAX45BHGi3mRRBF4_EfJ3XhpmVilSU,4355
226
226
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
227
227
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=c_QFo4e7t6b21CDakGhjVpqy5aXFxxqkdJ-s54HEOfs,31207
228
228
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
@@ -232,7 +232,7 @@ model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN
232
232
  model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
233
233
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
234
234
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
235
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
235
+ model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=K4L8FzJFM8_Ge2MHYkSqzCtoZe-ejEhVq8C1RgecyOc,14531
236
236
  model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=WccaNiHK12IIimYu29E1oJkQHUdhPCBcIRutefTQ3Ag,19903
237
237
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
238
238
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
528
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
529
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
530
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250416.541.dist-info/METADATA,sha256=r1uKB8w4EULCSj-_wL_b-doM7GuOlu4NeTVo11pYUj0,25413
532
- mct_nightly-2.3.0.20250416.541.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
533
- mct_nightly-2.3.0.20250416.541.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250416.541.dist-info/RECORD,,
531
+ mct_nightly-2.3.0.20250417.547.dist-info/METADATA,sha256=1DqfUY3Xwy1dlr0Utj2WraMYXYOVeLf5kly2Mq-3Uyw,25413
532
+ mct_nightly-2.3.0.20250417.547.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
533
+ mct_nightly-2.3.0.20250417.547.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
+ mct_nightly-2.3.0.20250417.547.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250416.000541"
30
+ __version__ = "2.3.0.20250417.000547"
@@ -22,6 +22,12 @@ from mct_quantizers import QuantizationMethod
22
22
  from model_compression_toolkit.defaultdict import DefaultDict
23
23
 
24
24
 
25
+ # Default value to use for ops without kernel.
26
+ # This is a weird default, but it's used all over the place, so for now only extract it to const so that it can be
27
+ # referenced by variable instead of hard-coded.
28
+ DEFAULT_KERNEL_ATTRIBUTES = [None]
29
+
30
+
25
31
  class ChannelAxis(Enum):
26
32
  """
27
33
 
@@ -16,7 +16,7 @@ from collections import namedtuple
16
16
 
17
17
  from copy import copy, deepcopy
18
18
  from functools import wraps
19
- from typing import List, Tuple, Any, Callable
19
+ from typing import List, Tuple, Any, Callable, Dict
20
20
 
21
21
  import networkx as nx
22
22
  import numpy as np
@@ -684,7 +684,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
684
684
  sorted_configurable_nodes.append(n)
685
685
  return sorted_configurable_nodes
686
686
 
687
- def get_min_candidates_config(self, fw_info: FrameworkInfo) -> List[int]:
687
+ def get_min_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
688
688
  """
689
689
  Builds a minimal configuration.
690
690
  Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate
@@ -694,18 +694,13 @@ class Graph(nx.MultiDiGraph, GraphSearches):
694
694
  Args:
695
695
  fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
696
696
 
697
- Returns: A list of candidate for each node (list on indices)
697
+ Returns:
698
+ A dict from layer to an index of its minimal candidate.
698
699
  """
699
-
700
700
  conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
701
- min_cfg_candidates = [n.find_min_candidates_indices() for n in conf_sorted_nodes] # list of lists of indices
702
-
703
- assert all([len(lst) == 1 for lst in min_cfg_candidates]), \
704
- f"A minimal config candidate must be defined, but some node have multiple potential minimal candidates"
705
-
706
- return [lst[0] for lst in min_cfg_candidates]
701
+ return {n: n.find_min_candidate_index() for n in conf_sorted_nodes}
707
702
 
708
- def get_max_candidates_config(self, fw_info: FrameworkInfo) -> List[int]:
703
+ def get_max_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
709
704
  """
710
705
  Builds a maximal configuration.
711
706
  Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate
@@ -715,16 +710,11 @@ class Graph(nx.MultiDiGraph, GraphSearches):
715
710
  Args:
716
711
  fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
717
712
 
718
- Returns: A list of candidate for each node (list on indices)
713
+ Returns:
714
+ A dict from layer to an index of its maximal candidate.
719
715
  """
720
-
721
716
  conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
722
- max_cfg_candidates = [n.find_max_candidates_indices() for n in conf_sorted_nodes] # list of lists of indices
723
-
724
- assert all([len(lst) == 1 for lst in max_cfg_candidates]), \
725
- f"A maximal config candidate must be defined, but some node have multiple potential maximal candidates"
726
-
727
- return [lst[0] for lst in max_cfg_candidates]
717
+ return {n: n.find_max_candidate_index() for n in conf_sorted_nodes}
728
718
 
729
719
  def get_final_weights_config(self, fw_info: FrameworkInfo) -> List[Tuple[BaseNode, int]]:
730
720
  """
@@ -484,49 +484,35 @@ class BaseNode:
484
484
  # for scalar shape (None,) prod returns 1
485
485
  return sum([np.prod([x for x in output_shape if x is not None]) for output_shape in output_shapes])
486
486
 
487
- def find_min_candidates_indices(self) -> List[int]:
487
+ def find_min_candidate_index(self) -> int:
488
488
  """
489
- Returns a list with potential minimal candidates.
490
- A potential minimal candidate is a candidate which its weights_n_bits and activation_n_bits pair is
491
- on the Pareto Front, i.e., there is no other candidate that its n_bits pair exceeds in both entries.
492
-
493
- Returns: A list of indices of potential minimal candidates.
494
-
495
- """
496
-
497
- # We assume that the candidates are sorted according to weights_n_bits first and activation_n_bits second
498
- # First, we add the last candidate to the set of minimal candidates (candidate, index)
499
- first_min = (len(self.candidates_quantization_cfg) - 1,
500
- self.candidates_quantization_cfg[-1].activation_quantization_cfg.activation_n_bits)
501
- min_candidates = [first_min]
502
-
503
- # Iterate over all other candidates, and add ones with higher weights_n_bits but smaller activation_n_bits
504
- for i, c in reversed(list(enumerate(self.candidates_quantization_cfg))):
505
- if c.activation_quantization_cfg.activation_n_bits < first_min[1]:
506
- min_candidates.append((i, c))
507
-
508
- return [i for i, a_n_bits in min_candidates]
509
-
510
- def find_max_candidates_indices(self) -> List[int]:
489
+ Returns:
490
+ The index of the minimal bit-width candidate.
511
491
  """
512
- Returns a list with potential maximal candidates.
513
- A potential maximal candidate is a candidate which its weights_n_bits and activation_n_bits pair is
514
- on the Pareto Front, i.e., there is no other candidates that its n_bits pair is lower in both entries.
492
+ aw_nbits = [(c.activation_quantization_cfg.activation_n_bits,
493
+ *[v.weights_n_bits for v in c.weights_quantization_cfg.get_all_weight_attrs_configs().values()])
494
+ for c in self.candidates_quantization_cfg]
495
+ min_nbits = min(aw_nbits)
496
+ min_ind = [i for i, nb in enumerate(aw_nbits) if min_nbits == nb]
497
+ # check that no other candidate has a lower nbit for any weight
498
+ if len(min_ind) > 1 or any(nb[i] < min_nbits[i] for i in range(len(min_nbits)) for nb in aw_nbits):
499
+ raise ValueError('Expected exactly one candidate with min activation and min weights.')
500
+ return min_ind[0]
515
501
 
516
- Returns: A list of indices of potential maximal candidates.
502
+ def find_max_candidate_index(self) -> int:
517
503
  """
518
-
519
- # We assume that the candidates are sorted according to weights_n_bits first and activation_n_bits second
520
- # First, we add the first candidate to the set of maximal candidates (candidate, index)
521
- first_max = (0, self.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits)
522
- max_candidates = [first_max]
523
-
524
- # Iterate over all other candidates, and add ones with higher weights_n_bits but smaller activation_n_bits
525
- for i, c in enumerate(self.candidates_quantization_cfg):
526
- if c.activation_quantization_cfg.activation_n_bits > first_max[1]:
527
- max_candidates.append((i, c))
528
-
529
- return [i for i, a_n_bits in max_candidates]
504
+ Returns:
505
+ The index of the maximal bit-width candidate.
506
+ """
507
+ aw_nbits = [(c.activation_quantization_cfg.activation_n_bits,
508
+ *[v.weights_n_bits for v in c.weights_quantization_cfg.get_all_weight_attrs_configs().values()])
509
+ for c in self.candidates_quantization_cfg]
510
+ max_nbits = max(aw_nbits)
511
+ max_ind = [i for i, nb in enumerate(aw_nbits) if max_nbits == nb]
512
+ # check that no other candidate has a higher nbit for any weight
513
+ if len(max_ind) > 1 or any(nb[i] > max_nbits[i] for i in range(len(max_nbits)) for nb in aw_nbits):
514
+ raise ValueError('Expected exactly one candidate with max activation and max weights.')
515
+ return max_ind[0]
530
516
 
531
517
  def get_unique_weights_candidates(self, attr: str) -> List[Any]:
532
518
  """
@@ -12,12 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import List, Set, Dict, Tuple
15
+ from typing import Set, Dict, Tuple
16
16
 
17
17
  import numpy as np
18
18
 
19
19
  from model_compression_toolkit.core import FrameworkInfo
20
- from model_compression_toolkit.core.common import Graph
20
+ from model_compression_toolkit.core.common import Graph, BaseNode
21
21
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
22
22
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
23
23
  RUTarget
@@ -36,7 +36,7 @@ class MixedPrecisionRUHelper:
36
36
  self.fw_impl = fw_impl
37
37
  self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
38
38
 
39
- def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: List[int]) -> Dict[RUTarget, np.ndarray]:
39
+ def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Dict[BaseNode, int]) -> Dict[RUTarget, np.ndarray]:
40
40
  """
41
41
  Compute utilization of requested targets for a specific configuration:
42
42
  for weights and bops - total utilization,
@@ -74,7 +74,7 @@ class MixedPrecisionRUHelper:
74
74
  f'Requested {ru_targets}')
75
75
  return ru_dict
76
76
 
77
- def get_quantization_candidates(self, mp_cfg) \
77
+ def get_quantization_candidates(self, mp_cfg: Dict[BaseNode, int]) \
78
78
  -> Tuple[Dict[str, NodeActivationQuantizationConfig], Dict[str, NodeWeightsQuantizationConfig]]:
79
79
  """
80
80
  Retrieve quantization candidates objects for weights and activations from the configuration list.
@@ -86,8 +86,7 @@ class MixedPrecisionRUHelper:
86
86
  A mapping between nodes to weights quantization config, and a mapping between nodes and activation
87
87
  quantization config.
88
88
  """
89
- mp_nodes = self.graph.get_configurable_sorted_nodes(self.fw_info)
90
- node_qcs = {n: n.candidates_quantization_cfg[mp_cfg[i]] for i, n in enumerate(mp_nodes)}
89
+ node_qcs = {n: n.candidates_quantization_cfg[candidate_idx] for n, candidate_idx in mp_cfg.items()}
91
90
  act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
92
91
  w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
93
92
  return act_qcs, w_qcs
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from enum import Enum
17
- from typing import List, Callable
17
+ from typing import List, Callable, Dict
18
18
 
19
19
  from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
20
- from model_compression_toolkit.core.common import Graph
20
+ from model_compression_toolkit.core.common import Graph, BaseNode
21
21
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
22
22
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
23
  from model_compression_toolkit.core.common.hessian import HessianInfoService
@@ -100,11 +100,13 @@ def search_bit_width(graph: Graph,
100
100
  fw_impl,
101
101
  se,
102
102
  target_resource_utilization)
103
- result_bit_cfg = search_manager.search()
103
+ nodes_bit_cfg = search_manager.search()
104
104
 
105
105
  graph.skip_validation_check = False
106
106
 
107
107
  if mp_config.refine_mp_solution:
108
- result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)
108
+ nodes_bit_cfg = greedy_solution_refinement_procedure(nodes_bit_cfg, search_manager, target_resource_utilization)
109
109
 
110
- return result_bit_cfg
110
+ topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes(fw_info)]
111
+ assert len(topo_bit_cfg) == len(nodes_bit_cfg)
112
+ return topo_bit_cfg
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import itertools
16
+
15
17
  import copy
16
18
  from collections import defaultdict
17
19
 
@@ -21,7 +23,6 @@ from typing import Dict, List, Tuple
21
23
 
22
24
  import numpy as np
23
25
 
24
- from model_compression_toolkit.constants import EPS
25
26
  from model_compression_toolkit.core.common import BaseNode
26
27
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
27
28
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -75,34 +76,44 @@ class MixedPrecisionSearchManager:
75
76
  self.target_resource_utilization = target_resource_utilization
76
77
 
77
78
  self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
78
- self.layer_to_bitwidth_mapping = self.get_search_space()
79
79
 
80
80
  self.ru_targets = target_resource_utilization.get_restricted_targets()
81
81
  self.ru_helper = MixedPrecisionRUHelper(self.mp_graph, fw_info, fw_impl)
82
82
 
83
- self.min_ru_config = self.mp_graph.get_min_candidates_config(fw_info)
84
- self.max_ru_config = self.mp_graph.get_max_candidates_config(fw_info)
83
+ self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
84
+ self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info)
85
85
  self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
86
86
 
87
87
  self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.mp_graph,
88
88
  original_graph=self.original_graph)
89
89
 
90
- def search(self) -> List[int]:
90
+ def search(self) -> Dict[BaseNode, int]:
91
91
  """
92
92
  Run mixed precision search.
93
93
 
94
94
  Returns:
95
- Indices of the selected bit-widths candidates.
95
+ Mapping from nodes to indices of the selected bit-widths candidate.
96
96
  """
97
- candidates_sensitivity = self._build_sensitivity_mapping()
98
- candidates_ru = self._compute_relative_ru_matrices()
99
- rel_target_ru = self._get_relative_ru_constraint_per_mem_element()
100
- solver = MixedPrecisionIntegerLPSolver(candidates_sensitivity, candidates_ru, rel_target_ru)
101
- config = solver.run()
97
+ mp_config = self._prepare_and_run_solver()
102
98
 
103
99
  if self.using_virtual_graph:
104
- config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(config)
105
- return config
100
+ mp_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(mp_config)
101
+
102
+ return mp_config
103
+
104
+ def _prepare_and_run_solver(self) -> Dict[BaseNode, int]:
105
+ """
106
+ Prepare sensitivity and ru data for LP solver and run the solver.
107
+
108
+ Returns:
109
+ Mapping from nodes to indices of the selected bit-widths candidate.
110
+ """
111
+ layers_candidates_sensitivity: Dict[BaseNode, List[float]] = self._build_sensitivity_mapping()
112
+ candidates_ru = self._compute_relative_ru_matrices()
113
+ rel_target_ru = self._get_relative_ru_constraint_per_mem_element()
114
+ solver = MixedPrecisionIntegerLPSolver(layers_candidates_sensitivity, candidates_ru, rel_target_ru)
115
+ mp_config = solver.run()
116
+ return mp_config
106
117
 
107
118
  def _get_relative_ru_constraint_per_mem_element(self) -> Dict[RUTarget, np.ndarray]:
108
119
  """
@@ -119,7 +130,7 @@ class MixedPrecisionSearchManager:
119
130
  """
120
131
  target_ru = self.target_resource_utilization.get_resource_utilization_dict(restricted_only=True)
121
132
  rel_target_ru = {
122
- ru_target: ru - self.min_ru[ru_target] for ru_target, ru in target_ru.items()
133
+ ru_target: (ru - self.min_ru[ru_target]) for ru_target, ru in target_ru.items()
123
134
  }
124
135
  unsatisfiable_targets = {
125
136
  ru_target.value: target_ru[ru_target] for ru_target, ru in rel_target_ru.items() if any(ru < 0)
@@ -129,28 +140,31 @@ class MixedPrecisionSearchManager:
129
140
  f"following targets: {unsatisfiable_targets}")
130
141
  return rel_target_ru
131
142
 
132
- def _build_sensitivity_mapping(self, eps: float = EPS) -> Dict[int, Dict[int, float]]:
143
+ def _build_sensitivity_mapping(self, eps: float = 1e-6) -> Dict[BaseNode, List[float]]:
133
144
  """
134
145
  This function measures the sensitivity of a change in a bitwidth of a layer on the entire model.
135
- It builds a mapping from a node's index, to its bitwidht's effect on the model sensitivity.
136
- For each node and some possible node's bitwidth (according to the given search space), we use
137
- the framework function compute_metric_fn in order to infer
138
- a batch of images, and compute (using the inference results) the sensitivity metric of
139
- the configured mixed-precision model.
140
146
 
141
147
  Args:
142
- eps: Epsilon value to manually increase metric value (if necessary) for numerical stability
148
+ eps: if sensitivity for a non-max candidate is lower than for a max candidate, we set it to
149
+ sensitivity of a max candidate + epsilon.
143
150
 
144
151
  Returns:
145
- Mapping from each node's index in a graph, to a dictionary from the bitwidth index (of this node) to
146
- the sensitivity of the model.
147
-
152
+ Mapping from nodes to their bitwidth candidates sensitivity.
148
153
  """
149
154
 
150
155
  Logger.info('Starting to evaluate metrics')
151
- layer_to_metrics_mapping = {}
152
156
 
153
- compute_metric = self.sensitivity_evaluator.compute_metric
157
+ orig_sorted_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
158
+
159
+ def topo_cfg(cfg: dict) -> list:
160
+ topo_cfg = [cfg[n] for n in orig_sorted_nodes]
161
+ assert len(topo_cfg) == len(cfg)
162
+ return topo_cfg
163
+
164
+ def compute_metric(cfg, node_idx=None, baseline_cfg=None):
165
+ return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
166
+ node_idx,
167
+ topo_cfg(baseline_cfg) if baseline_cfg else None)
154
168
  if self.using_virtual_graph:
155
169
  origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
156
170
  self.max_ru_config)
@@ -158,19 +172,17 @@ class MixedPrecisionSearchManager:
158
172
  else:
159
173
  max_config_value = compute_metric(self.max_ru_config)
160
174
 
161
- for node_idx, layer_possible_bitwidths_indices in tqdm(self.layer_to_bitwidth_mapping.items(),
162
- total=len(self.layer_to_bitwidth_mapping)):
163
- layer_to_metrics_mapping[node_idx] = {}
164
-
165
- for bitwidth_idx in layer_possible_bitwidths_indices:
166
- if self.max_ru_config[node_idx] == bitwidth_idx:
175
+ layer_to_metrics_mapping = defaultdict(list)
176
+ for node_idx, node in tqdm(enumerate(self.mp_topo_configurable_nodes)):
177
+ for bitwidth_idx, _ in enumerate(node.candidates_quantization_cfg):
178
+ if self.max_ru_config[node] == bitwidth_idx:
167
179
  # This is a computation of the metric for the max configuration, assign pre-calculated value
168
- layer_to_metrics_mapping[node_idx][bitwidth_idx] = max_config_value
180
+ layer_to_metrics_mapping[node].append(max_config_value)
169
181
  continue
170
182
 
171
183
  # Create a configuration that differs at one layer only from the baseline model
172
184
  mp_model_configuration = self.max_ru_config.copy()
173
- mp_model_configuration[node_idx] = bitwidth_idx
185
+ mp_model_configuration[node] = bitwidth_idx
174
186
 
175
187
  # Build a distance matrix using the function we got from the framework implementation.
176
188
  if self.using_virtual_graph:
@@ -180,8 +192,8 @@ class MixedPrecisionSearchManager:
180
192
  mp_model_configuration,
181
193
  changed_virtual_nodes_idx=[node_idx],
182
194
  original_base_config=origin_max_config)
183
- origin_changed_nodes_indices = [i for i, c in enumerate(origin_max_config) if
184
- c != origin_mp_model_configuration[i]]
195
+ origin_changed_nodes_indices = [i for i, (n, c) in enumerate(origin_max_config.items()) if
196
+ c != origin_mp_model_configuration[n]]
185
197
  metric_value = compute_metric(
186
198
  origin_mp_model_configuration,
187
199
  origin_changed_nodes_indices,
@@ -191,11 +203,11 @@ class MixedPrecisionSearchManager:
191
203
  mp_model_configuration,
192
204
  [node_idx],
193
205
  self.max_ru_config)
194
-
195
- layer_to_metrics_mapping[node_idx][bitwidth_idx] = max(metric_value, max_config_value + eps)
206
+ metric_value = max(metric_value, max_config_value + eps)
207
+ layer_to_metrics_mapping[node].append(metric_value)
196
208
 
197
209
  # Finalize distance metric mapping
198
- self.finalize_distance_metric(layer_to_metrics_mapping)
210
+ self._finalize_distance_metric(layer_to_metrics_mapping)
199
211
 
200
212
  return layer_to_metrics_mapping
201
213
 
@@ -221,22 +233,6 @@ class MixedPrecisionSearchManager:
221
233
 
222
234
  return graph, False
223
235
 
224
- def get_search_space(self) -> Dict[int, List[int]]:
225
- """
226
- The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces
227
- for the node).
228
-
229
- Returns:
230
- The entire search space of the graph.
231
- """
232
-
233
- indices_mapping = {}
234
- for idx, n in enumerate(self.mp_topo_configurable_nodes):
235
- # For each node, get all possible bitwidth indices for it
236
- # (which is a list from 0 to the length of the candidates mp_config list of the node).
237
- indices_mapping[idx] = list(range(len(n.candidates_quantization_cfg))) # all search_methods space
238
- return indices_mapping
239
-
240
236
  def _compute_relative_ru_matrices(self) -> Dict[RUTarget, np.ndarray]:
241
237
  """
242
238
  Computes and builds a resource utilization matrix for all restricted targets, to be used for the
@@ -248,55 +244,41 @@ class MixedPrecisionSearchManager:
248
244
  per ru target. Num memory elements depends on the target, e.g. num cuts or 1 for cumulative metrics.
249
245
  """
250
246
  rus_per_candidate = defaultdict(list)
251
- for c, c_n in enumerate(self.mp_topo_configurable_nodes):
252
- for candidate_idx in range(len(c_n.candidates_quantization_cfg)):
253
- if candidate_idx == self.min_ru_config[c]:
247
+ for node in self.mp_topo_configurable_nodes:
248
+ for candidate_idx, _ in enumerate(node.candidates_quantization_cfg):
249
+ if candidate_idx == self.min_ru_config[node]:
254
250
  candidate_rus = self.min_ru
255
251
  else:
256
- candidate_rus = self.compute_ru_for_candidate(c, candidate_idx)
252
+ cfg = self.min_ru_config.copy()
253
+ cfg[node] = candidate_idx
254
+ candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, cfg)
257
255
 
258
256
  for target, ru in candidate_rus.items():
259
257
  rus_per_candidate[target].append(ru)
260
258
 
261
259
  # Each target contains a matrix of num configurations X num elements
262
- relative_rus = {target: np.array(ru) - self.min_ru[target] for target, ru in rus_per_candidate.items()}
260
+ relative_rus = {target: (np.array(ru) - self.min_ru[target]) for target, ru in rus_per_candidate.items()}
263
261
  return relative_rus
264
262
 
265
- def compute_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int) -> Dict[RUTarget, np.ndarray]:
266
- """
267
- Computes a resource utilization vector after replacing the given node's configuration candidate in the minimal
268
- target configuration with the given candidate index.
269
-
270
- Args:
271
- conf_node_idx: The index of a node in a sorted configurable nodes list.
272
- candidate_idx: Quantization config candidate to be used for the node's resource utilization computation.
273
-
274
- Returns:
275
- Node's resource utilization vector.
276
-
277
- """
278
- cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
279
- return self.ru_helper.compute_utilization(self.ru_targets, cfg)
280
-
281
263
  @staticmethod
282
- def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
264
+ def copy_config_with_replacement(mp_cfg: Dict[BaseNode, int], node: BaseNode, candidate_idx: int) -> Dict[BaseNode, int]:
283
265
  """
284
- Replacing the quantization configuration candidate in a given mixed-precision configuration at the given
285
- index (node's index) with the given value (candidate index).
266
+ Create a copy of the given mixed-precision configuration and update the candidate index for a specific node.
286
267
 
287
268
  Args:
288
- mp_cfg: Mixed-precision configuration (list of candidates' indices)
289
- idx: A configurable node's index.
290
- value: A new candidate index to configure.
269
+ mp_cfg: Mixed-precision configuration.
270
+ node: Node to update the config for.
271
+ candidate_idx: A new candidate index to configure.
291
272
 
292
- Returns: A new mixed-precision configuration.
273
+ Returns:
274
+ A new mixed-precision configuration.
293
275
 
294
276
  """
295
277
  updated_cfg = mp_cfg.copy()
296
- updated_cfg[idx] = value
278
+ updated_cfg[node] = candidate_idx
297
279
  return updated_cfg
298
280
 
299
- def compute_resource_utilization_for_config(self, config: List[int]) -> ResourceUtilization:
281
+ def compute_resource_utilization_for_config(self, config: Dict[BaseNode, int]) -> ResourceUtilization:
300
282
  """
301
283
  Computes the resource utilization values for a given mixed-precision configuration.
302
284
 
@@ -313,7 +295,7 @@ class MixedPrecisionSearchManager:
313
295
  w_qcs=w_qcs, ru_targets=self.ru_targets, allow_unused_qcs=True)
314
296
  return ru
315
297
 
316
- def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]):
298
+ def _finalize_distance_metric(self, layer_to_metrics_mapping: Dict[BaseNode, List[float]]):
317
299
  """
318
300
  Finalizing the distance metric building.
319
301
  The method checks to see if the maximal distance value is larger than a given threshold, and if so,
@@ -321,21 +303,20 @@ class MixedPrecisionSearchManager:
321
303
  Modification to the dictionary is done inplace.
322
304
 
323
305
  Args:
324
- layer_to_metrics_mapping: A mapping between a node index to a mapping between
325
- a bitwidth index to a distance value.
306
+ layer_to_metrics_mapping: A mapping between a node to a list of distance values per bitwidth candidate.
326
307
 
327
308
  """
328
309
  # normalize metric for numerical stability
310
+ max_dist = max(itertools.chain.from_iterable(layer_to_metrics_mapping.values()))
329
311
 
330
- max_dist = max([max([d for b, d in dists.items()]) for layer, dists in layer_to_metrics_mapping.items()])
331
312
  if max_dist >= self.sensitivity_evaluator.quant_config.metric_normalization_threshold:
332
313
  Logger.warning(f"The mixed precision distance metric values indicate a large error in the quantized model."
333
314
  f"this can cause numerical issues."
334
315
  f"The program will proceed with mixed precision search after scaling the metric values,"
335
316
  f"which can lead to unstable results.")
336
317
  for layer, dists in layer_to_metrics_mapping.items():
337
- for b, d in dists.items():
338
- layer_to_metrics_mapping[layer][b] /= max_dist
318
+ for i, _ in enumerate(dists):
319
+ layer_to_metrics_mapping[layer][i] /= max_dist
339
320
 
340
321
 
341
322
  class ConfigReconstructionHelper:
@@ -363,7 +344,8 @@ class ConfigReconstructionHelper:
363
344
  self.fw_info = original_graph.fw_info
364
345
 
365
346
  self.virtual_sorted_nodes_names = self.virtual_graph.get_configurable_sorted_nodes_names(self.fw_info)
366
- self.origin_sorted_conf_nodes_names = self.original_graph.get_configurable_sorted_nodes_names(self.fw_info)
347
+ self.origin_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
348
+ self.origin_sorted_conf_nodes_names = [n.name for n in self.origin_sorted_conf_nodes]
367
349
 
368
350
  self.origin_node_idx_to_cfg = {}
369
351
 
@@ -375,9 +357,9 @@ class ConfigReconstructionHelper:
375
357
  self.origin_node_idx_to_cfg = {}
376
358
 
377
359
  def reconstruct_config_from_virtual_graph(self,
378
- virtual_mp_cfg: List[int],
360
+ virtual_mp_cfg: Dict[BaseNode, int],
379
361
  changed_virtual_nodes_idx: List[int] = None,
380
- original_base_config: List[int] = None) -> List[int]:
362
+ original_base_config: Dict[BaseNode, int] = None) -> Dict[BaseNode, int]:
381
363
  """
382
364
  Reconstructs the original config for a given virtual graph mixed-precision config.
383
365
  It iterates over all virtual configurable node (that has some chosen bit-width virtual candidate)
@@ -405,21 +387,21 @@ class ConfigReconstructionHelper:
405
387
  [(idx, self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)[idx]) for idx in changed_virtual_nodes_idx]
406
388
  # Iterating only over the virtual nodes that have updated config
407
389
  for virtual_node_idx, n in updated_virtual_nodes:
408
- self.reconstruct_node_config(n, virtual_mp_cfg, virtual_node_idx)
390
+ self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
409
391
  # Updating reconstructed config for all other nodes based on provided base_config
410
392
  original_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
411
- for i in range(len(original_base_config)):
393
+ for i, (n, qc_ind) in enumerate(original_base_config.items()):
412
394
  if i not in list(self.origin_node_idx_to_cfg.keys()):
413
- self.update_config_at_original_idx(n=original_sorted_conf_nodes[i],
414
- origin_cfg_idx=original_base_config[i])
395
+ self.update_config_at_original_idx(n=n, origin_cfg_idx=qc_ind)
415
396
  else:
416
397
  # Reconstruct entire config
417
398
  for virtual_node_idx, n in enumerate(self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)):
418
- self.reconstruct_node_config(n, virtual_mp_cfg, virtual_node_idx)
399
+ self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
419
400
 
420
401
  res_config = [self.origin_node_idx_to_cfg[key] for key in sorted(self.origin_node_idx_to_cfg.keys())]
421
402
  self._clear_reconstruction_dict()
422
- return res_config
403
+ assert len(res_config) == len(self.origin_sorted_conf_nodes)
404
+ return {n: candidate_idx for n, candidate_idx in zip(self.origin_sorted_conf_nodes, res_config)}
423
405
 
424
406
  def reconstruct_node_config(self,
425
407
  n: BaseNode,
@@ -12,9 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from collections import defaultdict
16
+
15
17
  import numpy as np
16
18
  from pulp import *
17
- from typing import Dict, Tuple, List
19
+ from typing import Dict, Tuple, Any
18
20
 
19
21
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
20
22
 
@@ -30,23 +32,23 @@ class MixedPrecisionIntegerLPSolver:
30
32
  candidates_ru: resource utilization per candidate.
31
33
  ru_constraints: resource utilization constraints corresponding to 'candidates_ru'.
32
34
  """
33
- def __init__(self, layer_to_sensitivity_mapping: Dict[int, Dict[int, float]],
35
+ def __init__(self,
36
+ layer_to_sensitivity_mapping: Dict[Any, List[float]],
34
37
  candidates_ru: Dict[RUTarget, np.ndarray],
35
38
  ru_constraints: Dict[RUTarget, np.ndarray]):
39
+
36
40
  self.layer_to_sensitivity_mapping = layer_to_sensitivity_mapping
37
41
  self.candidates_ru = candidates_ru
38
42
  self.ru_constraints = ru_constraints
39
43
 
40
- self.layer_to_indicator_vars_mapping, self.layer_to_objective_vars_mapping = (
41
- self._init_problem_vars(layer_to_sensitivity_mapping))
44
+ self.layer_to_indicator_vars, self.objective_vars = self._init_problem_vars(layer_to_sensitivity_mapping)
42
45
 
43
- def run(self) -> List[int]:
46
+ def run(self) -> Dict[Any, int]:
44
47
  """
45
48
  Build and solve an ILP optimization problem.
46
49
 
47
50
  Returns:
48
- The mixed-precision configuration (A list of indices. Each indicates the bitwidth index of a node).
49
-
51
+ A dictionary from layer to the index of the selected bitwidth candidate.
50
52
  """
51
53
  # Add all equations and inequalities that define the problem.
52
54
  lp_problem = self._formalize_problem()
@@ -59,17 +61,14 @@ class MixedPrecisionIntegerLPSolver:
59
61
  raise RuntimeError(f'No solution was found for the LP problem, with status {lp_problem.status}')
60
62
 
61
63
  # Take the bitwidth index only if its corresponding indicator is one.
62
- config = np.asarray(
63
- [[nbits for nbits, indicator in nbits_to_indicator.items() if indicator.varValue == 1.0] for
64
- nbits_to_indicator
65
- in self.layer_to_indicator_vars_mapping.values()]
66
- ).flatten()
67
-
68
- return config.tolist()
64
+ mp_config = {
65
+ layer: [v.varValue for v in vars].index(1.) for layer, vars in self.layer_to_indicator_vars.items()
66
+ }
67
+ return mp_config
69
68
 
70
69
  @staticmethod
71
- def _init_problem_vars(layer_to_metrics_mapping: Dict[int, Dict[int, float]]) -> Tuple[
72
- Dict[int, Dict[int, LpVariable]], Dict[int, LpVariable]]:
70
+ def _init_problem_vars(layer_to_metrics_mapping: Dict[Any, List[float]]) -> Tuple[Dict[Any, List[LpVariable]],
71
+ List[LpVariable]]:
73
72
  """
74
73
  Initialize the LP problem variables: Variable for each layer as to the index of the bitwidth it should use,
75
74
  and a variable for each indicator for whether we use the former variable or not.
@@ -83,21 +82,18 @@ class MixedPrecisionIntegerLPSolver:
83
82
  and the second for indicators for each variable.
84
83
  """
85
84
 
86
- layer_to_indicator_vars_mapping = dict()
87
- layer_to_objective_vars_mapping = dict()
88
-
89
- for layer, nbits_to_metric in layer_to_metrics_mapping.items():
90
- layer_to_indicator_vars_mapping[layer] = dict()
85
+ layer_to_indicator_vars = defaultdict(list)
86
+ objective_vars = []
91
87
 
92
- for nbits in nbits_to_metric.keys():
93
- layer_to_indicator_vars_mapping[layer][nbits] = LpVariable(f"layer_{layer}_{nbits}",
94
- lowBound=0,
95
- upBound=1,
96
- cat=LpInteger)
88
+ for layer_idx, (layer, bitwidth_metrics) in enumerate(layer_to_metrics_mapping.items()):
89
+ layer_to_indicator_vars[layer] = [
90
+ LpVariable(f"layer_{layer_idx}_{qc_idx}", lowBound=0, upBound=1, cat=LpInteger)
91
+ for qc_idx, _ in enumerate(bitwidth_metrics)
92
+ ]
97
93
 
98
- layer_to_objective_vars_mapping[layer] = LpVariable(f"s_{layer}", 0)
94
+ objective_vars.append(LpVariable(f"s_{layer_idx}", 0))
99
95
 
100
- return layer_to_indicator_vars_mapping, layer_to_objective_vars_mapping
96
+ return layer_to_indicator_vars, objective_vars
101
97
 
102
98
  def _formalize_problem(self) -> LpProblem:
103
99
  """
@@ -108,18 +104,16 @@ class MixedPrecisionIntegerLPSolver:
108
104
  """
109
105
 
110
106
  lp_problem = LpProblem() # minimization problem by default
111
- lp_problem += lpSum([self.layer_to_objective_vars_mapping[layer] for layer in
112
- self.layer_to_sensitivity_mapping.keys()]) # Objective (minimize acc loss)
107
+ lp_problem += lpSum(self.objective_vars)
113
108
 
114
- for layer in self.layer_to_sensitivity_mapping.keys():
109
+ for layer_sensitivity, layer_indicator_vars, obj_var in zip(self.layer_to_sensitivity_mapping.values(),
110
+ self.layer_to_indicator_vars.values(),
111
+ self.objective_vars):
115
112
  # Use every bitwidth for every layer with its indicator.
116
- lp_problem += lpSum([indicator * self.layer_to_sensitivity_mapping[layer][nbits]
117
- for nbits, indicator in self.layer_to_indicator_vars_mapping[layer].items()]) == \
118
- self.layer_to_objective_vars_mapping[layer]
113
+ lp_problem += lpSum(list(np.multiply(layer_indicator_vars, layer_sensitivity))) == obj_var
119
114
 
120
115
  # Constraint of only one indicator==1
121
- lp_problem += lpSum(
122
- [v for v in self.layer_to_indicator_vars_mapping[layer].values()]) == 1
116
+ lp_problem += lpSum(layer_indicator_vars) == 1
123
117
 
124
118
  # Bound the feasible solution space with the desired resource utilization values.
125
119
  self._add_ru_constraints(lp_problem=lp_problem)
@@ -134,10 +128,7 @@ class MixedPrecisionIntegerLPSolver:
134
128
  Args:
135
129
  lp_problem: An Lp problem object to add constraint to.
136
130
  """
137
- indicators = []
138
- for layer in self.layer_to_sensitivity_mapping:
139
- indicators.extend(list(self.layer_to_indicator_vars_mapping[layer].values()))
140
- indicators_vec = np.array(indicators)
131
+ indicator_vars = list(itertools.chain(*self.layer_to_indicator_vars.values()))
141
132
 
142
133
  for target, ru_matrix in self.candidates_ru.items():
143
134
  # We expect 2d matrix of shape (num candidates, m). For cumulative metrics (weights, bops) m=1 - overall
@@ -146,7 +137,7 @@ class MixedPrecisionIntegerLPSolver:
146
137
  if target in [RUTarget.WEIGHTS, RUTarget.BOPS]:
147
138
  assert ru_matrix.shape[1] == 1
148
139
 
149
- indicated_ru_matrix = ru_matrix.T * indicators_vec
140
+ indicated_ru_matrix = ru_matrix.T * np.array(indicator_vars)
150
141
  # build lp sum term over all candidates
151
142
  ru_vec = indicated_ru_matrix.sum(axis=1)
152
143
 
@@ -16,6 +16,7 @@
16
16
  from typing import List, Tuple, Dict
17
17
 
18
18
  from model_compression_toolkit.core import ResourceUtilization
19
+ from model_compression_toolkit.core.common import BaseNode
19
20
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
20
21
  MixedPrecisionSearchManager
21
22
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
@@ -23,9 +24,9 @@ from model_compression_toolkit.core.common.quantization.candidate_node_quantizat
23
24
  from model_compression_toolkit.logger import Logger
24
25
 
25
26
 
26
- def greedy_solution_refinement_procedure(mp_solution: List[int],
27
+ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
27
28
  search_manager: MixedPrecisionSearchManager,
28
- target_resource_utilization: ResourceUtilization) -> List[int]:
29
+ target_resource_utilization: ResourceUtilization) -> Dict[BaseNode, int]:
29
30
  """
30
31
  A greedy procedure to try and improve a mixed-precision solution that was found by a mixed-precision optimization
31
32
  algorithm.
@@ -50,6 +51,8 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
50
51
  Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
51
52
  return mp_solution
52
53
 
54
+ assert search_manager.using_virtual_graph is False
55
+
53
56
  new_solution = mp_solution.copy()
54
57
  changed = True
55
58
 
@@ -58,17 +61,16 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
58
61
  nodes_ru = {}
59
62
  nodes_next_candidate = {}
60
63
 
61
- for node_idx in range(len(mp_solution)):
62
- if new_solution[node_idx] == 0:
64
+ for node in search_manager.mp_topo_configurable_nodes:
65
+ if new_solution[node] == 0:
63
66
  # layer has max config in the given solution, nothing to optimize
64
67
  continue
65
68
 
66
- current_node = search_manager.mp_topo_configurable_nodes[node_idx]
67
- node_candidates = current_node.candidates_quantization_cfg
69
+ node_candidates = node.candidates_quantization_cfg
68
70
 
69
71
  # only weights kernel attribute is quantized with weights mixed precision
70
72
  valid_candidates = _get_valid_candidates_indices(node_candidates,
71
- new_solution[node_idx],
73
+ new_solution[node],
72
74
  target_resource_utilization.activation_restricted(),
73
75
  target_resource_utilization.weight_restricted()
74
76
  )
@@ -77,7 +79,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
77
79
  updated_ru = []
78
80
  for valid_idx in valid_candidates:
79
81
  node_updated_ru = search_manager.compute_resource_utilization_for_config(
80
- config=search_manager.replace_config_in_index(new_solution, node_idx, valid_idx))
82
+ config=search_manager.copy_config_with_replacement(new_solution, node, valid_idx))
81
83
  updated_ru.append(node_updated_ru)
82
84
 
83
85
  # filter out new configs that don't hold the resource utilization restrictions
@@ -88,8 +90,8 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
88
90
  sorted_by_ru = sorted(node_filtered_ru, key=lambda node_ru: (node_ru[1].total_memory,
89
91
  node_ru[1].weights_memory,
90
92
  node_ru[1].activation_memory))
91
- nodes_ru[node_idx] = sorted_by_ru[0][1]
92
- nodes_next_candidate[node_idx] = sorted_by_ru[0][0]
93
+ nodes_ru[node] = sorted_by_ru[0][1]
94
+ nodes_next_candidate[node] = sorted_by_ru[0][0]
93
95
 
94
96
  if len(nodes_ru) > 0:
95
97
  # filter out new configs that don't hold the ru restrictions
@@ -102,7 +104,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
102
104
  new_solution[node_idx_to_upgrade] = nodes_next_candidate[node_idx_to_upgrade]
103
105
  changed = True
104
106
 
105
- if any([mp_solution[i] != new_solution[i] for i in range(len(mp_solution))]):
107
+ if any([mp_solution[n] != new_solution[n] for n in mp_solution]):
106
108
  Logger.info(f'Greedy MP algorithm changed configuration from (numbers represent indices of the '
107
109
  f'chosen bit-width candidate for each layer):\n{mp_solution}\nto\n{new_solution}')
108
110
 
@@ -464,7 +464,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
464
464
  weights_attr_cfg=attr_cfg,
465
465
  weights_channels_axis=weights_channels_axis)
466
466
 
467
- def get_attr_config(self, attr_name: Union[str, int]) -> WeightsAttrQuantizationConfig:
467
+ def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
468
468
  """
469
469
  Returns a weights attribute config for an attribute that contains the given name.
470
470
  If multiple attributes that contain the given name are found - looking for the exact name, otherwise,
@@ -499,7 +499,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
499
499
 
500
500
  return attr_cfg
501
501
 
502
- def set_attr_config(self, attr_name: Union[str, int], attr_qc: WeightsAttrQuantizationConfig):
502
+ def set_attr_config(self, attr_name: 'WeightAttrT', attr_qc: WeightsAttrQuantizationConfig):
503
503
  """
504
504
  Adding a new attribute with quantization configuration to the node's weights configurations mapping.
505
505
 
@@ -513,7 +513,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
513
513
  else:
514
514
  self.attributes_config_mapping[attr_name] = attr_qc
515
515
 
516
- def has_attribute_config(self, attr_name: Union[str, int]) -> bool:
516
+ def has_attribute_config(self, attr_name: 'WeightAttrT') -> bool:
517
517
  """
518
518
  Checks whether the node weights configuration contains a configuration for a given weights attribute.
519
519
 
@@ -541,6 +541,14 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
541
541
  """
542
542
  return list(self.pos_attributes_config_mapping.keys()) + list(self.attributes_config_mapping.keys())
543
543
 
544
+ def get_all_weight_attrs_configs(self) -> Dict['WeightAttrT', AttributeQuantizationConfig]:
545
+ """ Get quantization configs for all weights.
546
+
547
+ Returns:
548
+ A dict from weight attribute to its config.
549
+ """
550
+ return {attr: self.get_attr_config(attr) for attr in self.all_weight_attrs}
551
+
544
552
  def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, WeightsAttrQuantizationConfig]:
545
553
  """
546
554
  Extract the saved attributes that contain the given attribute name.
@@ -560,7 +568,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
560
568
  return attrs_with_name
561
569
 
562
570
  def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
563
- attr_name: Union[str, int] = None, *args: List[Any], **kwargs: Dict[str, Any]):
571
+ attr_name: 'WeightAttrT' = None, *args: List[Any], **kwargs: Dict[str, Any]):
564
572
  """
565
573
  This method overrides the parent class set_quant_config_attr to enable setting a specific weights
566
574
  attribute config parameter.
@@ -137,11 +137,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
137
137
 
138
138
  float_weights = n.get_weights_by_keys(attr)
139
139
 
140
- max_cfg_candidates = n.find_max_candidates_indices()
141
- if not len(max_cfg_candidates) == 1:
142
- Logger.critical(f"A maximal configuration candidate must be defined; found multiple potential maximal candidates.")# pragma: no cover
143
-
144
- max_candidate_idx = max_cfg_candidates[0]
140
+ max_candidate_idx = n.find_max_candidate_index()
145
141
 
146
142
  return {'node_q_cfg': node_q_cfg_candidates,
147
143
  'float_weights': float_weights,
@@ -178,11 +174,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
178
174
  # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
179
175
  n.sort_node_candidates(self.fw_info)
180
176
 
181
- max_cfg_candidates = n.find_max_candidates_indices()
182
- assert len(max_cfg_candidates) == 1, \
183
- f"A maximal config candidate must be defined, but some node have multiple potential maximal candidates"
184
- max_candidate_idx = max_cfg_candidates[0]
185
-
177
+ max_candidate_idx = n.find_max_candidate_index()
186
178
  kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
187
179
  activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
188
180
  'max_candidate_idx': max_candidate_idx,
@@ -25,7 +25,7 @@ else:
25
25
  from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU # pragma: no cover
26
26
 
27
27
  from model_compression_toolkit.defaultdict import DefaultDict
28
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
28
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
29
29
  from mct_quantizers import QuantizationMethod
30
30
  from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
31
31
  from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
@@ -39,7 +39,7 @@ If a layer that is not listed here is queried, [None] is returned.
39
39
  KERNEL_ATTRIBUTES = DefaultDict({Conv2D: [KERNEL],
40
40
  DepthwiseConv2D: [DEPTHWISE_KERNEL],
41
41
  Dense: [KERNEL],
42
- Conv2DTranspose: [KERNEL]}, [None])
42
+ Conv2DTranspose: [KERNEL]}, DEFAULT_KERNEL_ATTRIBUTES)
43
43
 
44
44
 
45
45
  """
@@ -136,11 +136,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
136
136
 
137
137
  float_weights = n.get_weights_by_keys(attr)
138
138
 
139
- max_cfg_candidates = n.find_max_candidates_indices()
140
- if not len(max_cfg_candidates) == 1:
141
- Logger.critical(f"A maximal configuration candidate must be uniquely defined; however, multiple potential maximal candidates were found.") # pragma: no cover
142
-
143
- max_candidate_idx = max_cfg_candidates[0]
139
+ max_candidate_idx = n.find_max_candidate_index()
144
140
 
145
141
  return {'node_q_cfg': node_q_cfg_candidates,
146
142
  'float_weights': float_weights,
@@ -175,10 +171,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
175
171
  # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
176
172
  n.sort_node_candidates(self.fw_info)
177
173
 
178
- max_cfg_candidates = n.find_max_candidates_indices()
179
- assert len(max_cfg_candidates) == 1, \
180
- f"A maximal configuration candidate must be uniquely defined; however, multiple potential maximal candidates were found."
181
- max_candidate_idx = max_cfg_candidates[0]
174
+ max_candidate_idx = n.find_max_candidate_index()
182
175
 
183
176
  kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
184
177
  activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
@@ -18,7 +18,7 @@ from torch.nn import Conv2d, ConvTranspose2d, Linear
18
18
  from torch import sigmoid
19
19
 
20
20
  from model_compression_toolkit.defaultdict import DefaultDict
21
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
21
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
22
22
  from mct_quantizers import QuantizationMethod
23
23
  from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
24
24
  from model_compression_toolkit.core.pytorch.constants import KERNEL
@@ -33,7 +33,7 @@ If a layer that is not listed here is queried, [None] is returned.
33
33
  KERNEL_ATTRIBUTES = DefaultDict({Conv2d: [KERNEL],
34
34
  ConvTranspose2d: [KERNEL],
35
35
  Linear: [KERNEL]},
36
- [None])
36
+ DEFAULT_KERNEL_ATTRIBUTES)
37
37
 
38
38
  """
39
39
  Map a layer to its kernel's output and input channels indices.