mct-nightly 2.3.0.20250512.625__py3-none-any.whl → 2.3.0.20250513.611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23) hide show
  1. {mct_nightly-2.3.0.20250512.625.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250512.625.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/RECORD +23 -23
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +6 -33
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +22 -3
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +8 -5
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +69 -58
  8. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +82 -79
  9. model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py +32 -26
  10. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -4
  11. model_compression_toolkit/core/common/quantization/node_quantization_config.py +7 -0
  12. model_compression_toolkit/core/common/similarity_analyzer.py +1 -1
  13. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +37 -73
  14. model_compression_toolkit/core/keras/keras_implementation.py +8 -45
  15. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +7 -5
  16. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +6 -5
  17. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +46 -78
  18. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +7 -9
  19. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +12 -10
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -41
  21. {mct_nightly-2.3.0.20250512.625.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/WHEEL +0 -0
  22. {mct_nightly-2.3.0.20250512.625.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/licenses/LICENSE.md +0 -0
  23. {mct_nightly-2.3.0.20250512.625.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250512.625
3
+ Version: 2.3.0.20250513.611
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Author-email: ssi-dnn-dev@sony.com
6
6
  Classifier: Programming Language :: Python :: 3
@@ -1,5 +1,5 @@
1
- mct_nightly-2.3.0.20250512.625.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=eq4dWyKngj1OrsAWFhmI9beCZaLAWyvU5DgVpuAMXK4,1557
1
+ mct_nightly-2.3.0.20250513.611.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=p_G6GkwHl_GiPtc0E2qL6iUBG-UpYcgFx1HDi073s0Q,1557
3
3
  model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -12,14 +12,14 @@ model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8
12
12
  model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utybvIVbH06CBHY,13056
13
13
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
14
14
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
15
- model_compression_toolkit/core/common/framework_implementation.py,sha256=L88uv_sfYM_56FSmxXP--emjv01_lk7IPqOI7QBZEt0,22939
15
+ model_compression_toolkit/core/common/framework_implementation.py,sha256=JQI_eoZZoNk5Y_jAxLfYt9-wzfs7zGpTldz9UblxmMc,21182
16
16
  model_compression_toolkit/core/common/framework_info.py,sha256=5tderHT-7Cd21QrRFIJj3hH_gAcnlivOzwZ5m1ldJOs,6526
17
17
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
18
18
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
19
19
  model_compression_toolkit/core/common/model_collector.py,sha256=Tno3-qx9jmPZAZyLYgbPlMLHakVfuEH5deuToZNuCb0,13195
20
20
  model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
21
21
  model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
22
- model_compression_toolkit/core/common/similarity_analyzer.py,sha256=FikcIqgQQpfiXr9VJvgl-wk8OyH7-LvC8ku7TkhJfJM,9200
22
+ model_compression_toolkit/core/common/similarity_analyzer.py,sha256=S3f6WgHyw62dGcxpX51FGKyfebe2zv9ABKbjtGyKRvY,9215
23
23
  model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
24
24
  model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
25
25
  model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=V1oShKzbSkdcTvREn8VnQQBzvm-tTHkWMXqMkYozF2s,2023
@@ -66,13 +66,13 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,s
66
66
  model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
67
67
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
68
68
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
69
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=onHgDwfw8CUbZFNU-RYit9eqA6FrzAtFA3akVZ2d7IM,4533
69
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=BO4ouM_UVS9Fg0z95gLJSMz1ep6YQC5za_iXI_qW2yQ,5399
70
70
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=-hOMBucYn12ePyLd0b1KxniPOIRu4b53SwEzv0bWToI,4943
71
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=d5-3j2e_rdcQOT7c4s0p7640i3nSetjJ6MgMhhMM7dc,6152
72
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=Lk5cftihGpgFQoyqnRGiwJFFqkI8dkx0l1q0sVJi2CE,27505
73
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=R3UIO9lKf-lpEGfJOqgpQAXdP1IWMatWxXKYDkhWj_E,28096
74
- model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
75
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
71
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=KhiHGpmN5QbpyJQnTZmXigdXFlSlRNqpOOyKGj1Fwek,6412
72
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=OzRhlJ2IS9Dwv0rgobee0xTtAeRwlBC6KvVEcx2_oB0,28089
73
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4uhUXKgwyMrJqEVK5uJzVr67GI5YzDTHLveV4maB7z0,28079
74
+ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=Zn6SgzGLWWKmuYGHd1YtKxZdYnQWRDeXEkKlBiTbHcs,2929
75
+ model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=MY8df-c_kITEr_7hOctaxhdiq29hSTA0La9Qo0oTJJY,9678
76
76
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
77
77
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
78
78
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=-kNcmQQFVHRPizInaRrCEIuh_q_57CWxC6CIV6azF4g,39640
@@ -106,7 +106,7 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
106
106
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
107
107
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
108
108
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=n2A8pO7_DMMae4o69U0I00iW6mzeRlRfKHDxlQUBBuI,7204
109
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=0OJZtQuv-StbKZOpalvGi9lcpHJNRPeuclevSaCPggc,29792
109
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=XmuG1ZwBxZEEcVKBwCo_v3vfjVjIJqtyO94QCeczddw,30131
110
110
  model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=UkSVW7d1OF_Px9gAjsqqK65aYhIBFWaBO-_IH6_AFfg,4403
111
111
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=HfBkSiRTOf9mNF-TNQHTCCs3xSg66F20no0O6vl5v1Y,2154
112
112
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=7eG7dl1TcbdnHwgmvyjarxLs0o6Lw_9VAjXAm4rsiBk,3791
@@ -158,7 +158,7 @@ model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiO
158
158
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
159
159
  model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
160
160
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=DvK1Tr6z3cQlJw1nx62iFaeSsQSXJl55xOIcJ1uNGu8,5020
161
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=_15BrSGTRSSp_8ayuo2x-hdKanew1xuIPSumP46IGSA,32545
161
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=bQFnX6OjUBo3q4aEPrxbXigqSoIqAsm5YLoXWDvqghE,30048
162
162
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
163
163
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
164
164
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=XBCmUrHy_fNQCfSjnXCpwuEtc7cda4hXySuiIzhFGqc,5696
@@ -168,7 +168,7 @@ model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha
168
168
  model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
169
169
  model_compression_toolkit/core/keras/back2framework/instance_builder.py,sha256=fBj13c6zkVoWX4JJG18_uXPptiEJqXClE_zFbaFB6Q8,4517
170
170
  model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=TY86-Mb8hmo8RgCcQvkSYIthYOqV9e4VIMpqIyouJ4Y,17397
171
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=BTDJB6VUAyVapzkwnftdXkv9RaQfwp_GIEk1FyovdGg,14813
171
+ model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=s3uxI5X7q82L1Bui02Z16yeGjKBiFXjjJoPoBhUGO0Y,11804
172
172
  model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py,sha256=5wFb4nx_F0Wu4c8pLf6n6OzxOHtpOJ6_3mQsNSXIudU,2481
173
173
  model_compression_toolkit/core/keras/graph_substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
174
174
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
@@ -198,8 +198,8 @@ model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculato
198
198
  model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=1o7X9GXSfpEmuB5ee2AaBQ2sN2xzX4-smbrq_0qOGRU,4454
199
199
  model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=Rl6NNGkHMV0ioEM5bbM4XX7yHDqG6mMp4ifN2VQBDxE,12168
200
200
  model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
201
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=WFwPtCcXR3qY86OML_jyzasvdd2DGhy4-GveAGpDOt0,5075
202
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=38Lvwux9L35oT6muck6_FH7nDdH2N8_kuGDMj4-QNpE,6647
201
+ model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aB3y8vaiXSmy1bpvlqXDswL3-FTz019s5r-lcb4FKhE,5254
202
+ model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=GtW0yars8PzqP9uL_vfXrtqHwKiStmOxPng20rYaIjU,6805
203
203
  model_compression_toolkit/core/keras/pruning/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
204
204
  model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=EJkblZ4OAjI5l29GKsUraam5Jn58Sogld47_rFFyr3k,12777
205
205
  model_compression_toolkit/core/keras/quantizer/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
@@ -224,7 +224,7 @@ model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3Kqh
224
224
  model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
225
225
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-byHTXmQEuOiqTAX45BHGi3mRRBF4_EfJ3XhpmVilSU,4355
226
226
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
227
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=c_QFo4e7t6b21CDakGhjVpqy5aXFxxqkdJ-s54HEOfs,31207
227
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=33P9kttfgVKtbqOgmstSLuzs6MSwXaK8MbvtAScqiBI,28755
228
228
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
229
229
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=aIHl-dTAC4ISnWSKLD99c-1W3827vfRGyLjMBib-l3s,5618
230
230
  model_compression_toolkit/core/pytorch/utils.py,sha256=xNVE7YMtHupLEimIJcxmfcMGM4XKB9I1v0-K8lDeLB8,3936
@@ -232,7 +232,7 @@ model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN
232
232
  model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
233
233
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
234
234
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
235
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=Pyk2hrwCB-9HyhZUqtKFK_t6YBL5yKB-S0CStfEhx_M,14675
235
+ model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=HAzzWOnPcIeDxQO1712254RNTBZD-gVSMSVnxqpfuQ0,11907
236
236
  model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Z-ZQV-GWdOBGPbksiWBQ8MtFkQ41qgUKU5d5c8aNSjQ,21646
237
237
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
238
238
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -268,8 +268,8 @@ model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calcula
268
268
  model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=8f_XlM8ZFVQPNGr1iECr1hv8QusYDrNU_vTkLQZE9RU,2477
269
269
  model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=UzWxWDbr8koKZatEcPn8RCb0Zjm_7fKTvIGb98sp18k,8487
270
270
  model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
271
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=mcY_KOQgABIqGIMh0x6mNxaKp7SFNbkEIYavR2X7SQ4,4754
272
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=zp1Xp75IDf9LN5YGO2UzeDbms_6ICQ_pSE1ORQr-SA8,6281
271
+ model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=DVQrEJbB7MKj_LitU92cBxDApwnAAkilYvQzkr79ffg,4813
272
+ model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KDnwmbhvhJMfNg1IuTvvzBNEriPQH9bL9dJ5VvWTzpE,6631
273
273
  model_compression_toolkit/core/pytorch/pruning/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
274
274
  model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=VfEEVwWEXKpVlZFnr7N6mvEjcpq85ROLg05ZvXfD1Pg,14764
275
275
  model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
528
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
529
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
530
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250512.625.dist-info/METADATA,sha256=Rl3DUdbepLzTi6pJhMjSt65rk2rnMQgUiJYzRi-dCPk,25136
532
- mct_nightly-2.3.0.20250512.625.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
533
- mct_nightly-2.3.0.20250512.625.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250512.625.dist-info/RECORD,,
531
+ mct_nightly-2.3.0.20250513.611.dist-info/METADATA,sha256=dx0fsYTzsB_Y1IVuSNMaJPgPO4lhotb3TlDZ-dq2JF8,25136
532
+ mct_nightly-2.3.0.20250513.611.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
533
+ mct_nightly-2.3.0.20250513.611.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
+ mct_nightly-2.3.0.20250513.611.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250512.000625"
30
+ __version__ = "2.3.0.20250513.000611"
@@ -13,24 +13,20 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from abc import ABC, abstractmethod
16
- from typing import Callable, Any, List, Tuple, Dict, Generator
16
+ from typing import Callable, Any, List, Tuple, Generator, Type
17
17
 
18
18
  import numpy as np
19
19
 
20
20
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
21
- from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
22
21
  from model_compression_toolkit.core import common
23
22
  from model_compression_toolkit.core.common import BaseNode
24
- from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
25
23
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
26
24
  from model_compression_toolkit.core.common.graph.base_graph import Graph
27
- from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianInfoService
28
- from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
25
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest
29
26
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
30
27
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
31
28
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
32
29
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
33
- from model_compression_toolkit.core.common.user_info import UserInformation
34
30
 
35
31
 
36
32
  class FrameworkImplementation(ABC):
@@ -38,6 +34,10 @@ class FrameworkImplementation(ABC):
38
34
  An abstract class with abstract methods that should be implemented when supporting a new
39
35
  framework in MCT.
40
36
  """
37
+ weights_quant_layer_cls: Type
38
+ activation_quant_layer_cls: Type
39
+ configurable_weights_quantizer_cls: Type
40
+ configurable_activation_quantizer_cls: Type
41
41
 
42
42
  @property
43
43
  def constants(self):
@@ -327,33 +327,6 @@ class FrameworkImplementation(ABC):
327
327
  f'framework\'s get_substitutions_after_second_moment_correction '
328
328
  f'method.') # pragma: no cover
329
329
 
330
- @abstractmethod
331
- def get_sensitivity_evaluator(self,
332
- graph: Graph,
333
- quant_config: MixedPrecisionQuantizationConfig,
334
- representative_data_gen: Callable,
335
- fw_info: FrameworkInfo,
336
- hessian_info_service: HessianInfoService = None,
337
- disable_activation_for_metric: bool = False) -> SensitivityEvaluation:
338
- """
339
- Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
340
- configuration (comparing to the float model).
341
-
342
- Args:
343
- graph: Graph to build its float and mixed-precision models.
344
- quant_config: QuantizationConfig of how the model should be quantized.
345
- representative_data_gen: Dataset to use for retrieving images for the models inputs.
346
- fw_info: FrameworkInfo object with information about the specific framework's model.
347
- disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
348
- hessian_info_service: HessianInfoService to fetch information based on Hessian-approximation.
349
-
350
- Returns:
351
- A function that computes the metric.
352
- """
353
-
354
- raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
355
- f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
356
-
357
330
  def get_node_prior_info(self, node: BaseNode,
358
331
  fw_info: FrameworkInfo,
359
332
  graph: Graph) -> NodePriorInfo:
@@ -14,11 +14,23 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from dataclasses import dataclass, field
17
+ from enum import Enum
17
18
  from typing import List, Callable, Optional
18
19
  from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
19
20
  from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
20
21
 
21
22
 
23
+ class MpMetricNormalization(Enum):
24
+ """
25
+ MAXBIT: normalize sensitivity metrics of layer candidates by max-bitwidth candidate (of that layer).
26
+ MINBIT: normalize sensitivity metrics of layer candidates by min-bitwidth candidate (of that layer).
27
+ NONE: no normalization.
28
+ """
29
+ MAXBIT = 'MAXBIT'
30
+ MINBIT = 'MINBIT'
31
+ NONE = 'NONE'
32
+
33
+
22
34
  @dataclass
23
35
  class MixedPrecisionQuantizationConfig:
24
36
  """
@@ -27,7 +39,6 @@ class MixedPrecisionQuantizationConfig:
27
39
  Args:
28
40
  compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
29
41
  distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
30
- custom_metric_fn (Callable): Function to compute a custom metric. As input gets the model_mp and returns a float value for metric. If None, uses interest point metric.
31
42
  num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
32
43
  configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
33
44
  num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
@@ -36,11 +47,16 @@ class MixedPrecisionQuantizationConfig:
36
47
  refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
37
48
  metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
38
49
  hessian_batch_size (int): The Hessian computation batch size. used only if using mixed precision with Hessian-based objective.
39
- """
50
+ metric_normalization (MpMetricNormalization): Metric normalization method.
51
+ metric_epsilon (float | None): ensure minimal distance between the metric for any non-max-bidwidth candidate
52
+ and a max-bitwidth candidate, i.e. metric(non-max-bitwidth) >= metric(max-bitwidth) + epsilon.
53
+ If none, the computed metrics are used as is.
54
+ custom_metric_fn (Callable): Function to compute a custom metric. As input gets the model_mp and returns a
55
+ float value for metric. If None, uses interest point metric.
40
56
 
57
+ """
41
58
  compute_distance_fn: Optional[Callable] = None
42
59
  distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
43
- custom_metric_fn: Optional[Callable] = None
44
60
  num_of_images: int = MP_DEFAULT_NUM_SAMPLES
45
61
  configuration_overwrite: Optional[List[int]] = None
46
62
  num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})
@@ -49,6 +65,9 @@ class MixedPrecisionQuantizationConfig:
49
65
  refine_mp_solution: bool = True
50
66
  metric_normalization_threshold: float = 1e10
51
67
  hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
68
+ metric_normalization: MpMetricNormalization = MpMetricNormalization.NONE
69
+ metric_epsilon: Optional[float] = 1e-6
70
+ custom_metric_fn: Optional[Callable] = None
52
71
  _is_mixed_precision_enabled: bool = field(init=False, default=False)
53
72
 
54
73
  def __post_init__(self):
@@ -25,6 +25,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_searc
25
25
  MixedPrecisionSearchManager
26
26
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
27
27
  ResourceUtilization
28
+ from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
28
29
  from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
29
30
  greedy_solution_refinement_procedure
30
31
 
@@ -78,11 +79,12 @@ def search_bit_width(graph: Graph,
78
79
 
79
80
  # Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
80
81
  # even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
81
- se = fw_impl.get_sensitivity_evaluator(
82
+ se = SensitivityEvaluation(
82
83
  graph,
83
84
  mp_config,
84
85
  representative_data_gen=representative_data_gen,
85
86
  fw_info=fw_info,
87
+ fw_impl=fw_impl,
86
88
  disable_activation_for_metric=disable_activation_for_metric,
87
89
  hessian_info_service=hessian_info_service)
88
90
 
@@ -96,10 +98,11 @@ def search_bit_width(graph: Graph,
96
98
 
97
99
  # Search manager and LP are highly coupled, so LP search method was moved inside search manager.
98
100
  search_manager = MixedPrecisionSearchManager(graph,
99
- fw_info,
100
- fw_impl,
101
- se,
102
- target_resource_utilization)
101
+ fw_info=fw_info,
102
+ fw_impl=fw_impl,
103
+ sensitivity_evaluator=se,
104
+ target_resource_utilization=target_resource_utilization,
105
+ mp_config=mp_config)
103
106
  nodes_bit_cfg = search_manager.search()
104
107
 
105
108
  graph.skip_validation_check = False
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import os
16
+
15
17
  import itertools
16
18
 
17
19
  import copy
@@ -19,7 +21,7 @@ from collections import defaultdict
19
21
 
20
22
  from tqdm import tqdm
21
23
 
22
- from typing import Dict, List, Tuple, Optional
24
+ from typing import Dict, List, Tuple, Optional, Set
23
25
 
24
26
  import numpy as np
25
27
 
@@ -40,6 +42,8 @@ from model_compression_toolkit.core.common.mixed_precision.search_methods.linear
40
42
  from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
41
43
  from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
42
44
  from model_compression_toolkit.logger import Logger
45
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
46
+ MixedPrecisionQuantizationConfig, MpMetricNormalization
43
47
 
44
48
 
45
49
  class MixedPrecisionSearchManager:
@@ -52,7 +56,8 @@ class MixedPrecisionSearchManager:
52
56
  fw_info: FrameworkInfo,
53
57
  fw_impl: FrameworkImplementation,
54
58
  sensitivity_evaluator: SensitivityEvaluation,
55
- target_resource_utilization: ResourceUtilization):
59
+ target_resource_utilization: ResourceUtilization,
60
+ mp_config: MixedPrecisionQuantizationConfig):
56
61
  """
57
62
 
58
63
  Args:
@@ -74,21 +79,21 @@ class MixedPrecisionSearchManager:
74
79
 
75
80
  self.sensitivity_evaluator = sensitivity_evaluator
76
81
  self.target_resource_utilization = target_resource_utilization
82
+ self.mp_config = mp_config
77
83
 
78
84
  self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
79
85
 
80
86
  self.ru_targets = target_resource_utilization.get_restricted_targets()
81
- self.ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
87
+ self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
82
88
 
83
89
  self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
84
- self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info)
85
90
 
86
- self.config_reconstruction_helper = ConfigReconstructionHelper(self.original_graph)
91
+ self.config_reconstructor = None
92
+ orig_min_config = self.min_ru_config
87
93
  if self.using_virtual_graph:
88
- real_min_ru_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.min_ru_config)
89
- self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, real_min_ru_config)
90
- else:
91
- self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
94
+ self.config_reconstructor = ConfigReconstructionHelper(self.original_graph)
95
+ orig_min_config = self.config_reconstructor.reconstruct_full_configuration(self.min_ru_config)
96
+ self.min_ru = self.orig_graph_ru_helper.compute_utilization(self.ru_targets, orig_min_config)
92
97
 
93
98
  def search(self) -> Dict[BaseNode, int]:
94
99
  """
@@ -100,7 +105,7 @@ class MixedPrecisionSearchManager:
100
105
  mp_config = self._prepare_and_run_solver()
101
106
 
102
107
  if self.using_virtual_graph:
103
- mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_config)
108
+ mp_config = self.config_reconstructor.reconstruct_full_configuration(mp_config)
104
109
 
105
110
  return mp_config
106
111
 
@@ -143,61 +148,64 @@ class MixedPrecisionSearchManager:
143
148
  f"following targets: {unsatisfiable_targets}")
144
149
  return rel_target_ru
145
150
 
146
- def _build_sensitivity_mapping(self, eps: float = 1e-6) -> Dict[BaseNode, List[float]]:
151
+ def _build_sensitivity_mapping(self) -> Dict[BaseNode, List[float]]:
147
152
  """
148
153
  This function measures the sensitivity of a change in a bitwidth of a layer on the entire model.
149
154
 
150
- Args:
151
- eps: if sensitivity for a non-max candidate is lower than for a max candidate, we set it to
152
- sensitivity of a max candidate + epsilon.
153
-
154
155
  Returns:
155
156
  Mapping from nodes to their bitwidth candidates sensitivity.
156
157
  """
157
-
158
158
  Logger.info('Starting to evaluate metrics')
159
-
160
- orig_sorted_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
161
-
162
- def topo_cfg(cfg: dict) -> list:
163
- topo_cfg = [cfg[n] for n in orig_sorted_nodes]
164
- assert len(topo_cfg) == len(cfg)
165
- return topo_cfg
166
-
167
- def compute_metric(cfg, node_idx=None, baseline_cfg=None):
168
- return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
169
- node_idx,
170
- topo_cfg(baseline_cfg) if baseline_cfg else None)
171
-
172
- if self.using_virtual_graph:
173
- origin_max_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.max_ru_config)
174
- max_config_value = compute_metric(origin_max_config)
175
- else:
176
- max_config_value = compute_metric(self.max_ru_config)
159
+ norm_method = self.mp_config.metric_normalization
160
+ eps = self.mp_config.metric_epsilon
161
+
162
+ verbose = 'VERBOSE_MP_METRIC' in os.environ
163
+
164
+ def normalize(node_candidates_metrics, max_ind):
165
+ if norm_method == MpMetricNormalization.NONE:
166
+ return node_candidates_metrics
167
+ if norm_method == MpMetricNormalization.MAXBIT:
168
+ ref_ind = max_ind
169
+ elif norm_method == MpMetricNormalization.MINBIT:
170
+ ref_ind = node.find_min_candidate_index()
171
+ else: # pragma: no cover
172
+ raise ValueError(f'Unexpected MpMetricNormalization mode {norm_method}')
173
+ normalized_metrics = node_candidates_metrics / node_candidates_metrics[ref_ind]
174
+ if verbose and not np.array_equal(normalized_metrics, node_candidates_metrics):
175
+ print(f'{"normalized metric:":25}', candidates_sensitivity)
176
+ return normalized_metrics
177
+
178
+ def ensure_maxbit_minimal_metric(node_candidates_metrics, max_ind):
179
+ if eps is None:
180
+ return node_candidates_metrics
181
+ # We want maxbit configuration to have the minimal distance metric (so that optimization objective
182
+ # doesn't prefer lower bits). If we got a smaller metric for non-maxbit, we update it to metric(maxbit)+eps.
183
+ max_val = node_candidates_metrics[max_ind]
184
+ metrics = np.maximum(node_candidates_metrics, max_val + eps)
185
+ metrics[max_ind] = max_val
186
+ if verbose and not np.array_equal(metrics, node_candidates_metrics):
187
+ print(f'{"eps-adjusted metric:":25}', candidates_sensitivity)
188
+ return metrics
177
189
 
178
190
  layer_to_metrics_mapping = defaultdict(list)
179
191
  for node_idx, node in tqdm(enumerate(self.mp_topo_configurable_nodes)):
192
+ candidates_sensitivity = np.empty(len(node.candidates_quantization_cfg))
180
193
  for bitwidth_idx, _ in enumerate(node.candidates_quantization_cfg):
181
- if self.max_ru_config[node] == bitwidth_idx:
182
- # This is a computation of the metric for the max configuration, assign pre-calculated value
183
- layer_to_metrics_mapping[node].append(max_config_value)
184
- continue
185
-
186
- # Create a configuration that differs at one layer only from the baseline model
187
- mp_model_configuration = self.max_ru_config.copy()
188
- mp_model_configuration[node] = bitwidth_idx
189
-
190
- # Build a distance matrix using the function we got from the framework implementation.
191
194
  if self.using_virtual_graph:
192
- # Reconstructing original graph's configuration from virtual graph's configuration
193
- orig_mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_model_configuration)
194
- changed_nodes = [orig_sorted_nodes.index(n) for n, ind in orig_mp_config.items()
195
- if origin_max_config[n] != ind]
196
- metric_value = compute_metric(orig_mp_config, changed_nodes, origin_max_config)
195
+ a_cfg, w_cfg = self.config_reconstructor.reconstruct_separate_aw_configs({node: bitwidth_idx})
197
196
  else:
198
- metric_value = compute_metric(mp_model_configuration, [node_idx], self.max_ru_config)
199
- metric_value = max(metric_value, max_config_value + eps)
200
- layer_to_metrics_mapping[node].append(metric_value)
197
+ a_cfg = {node: bitwidth_idx} if node.has_configurable_activation() else {}
198
+ w_cfg = {node: bitwidth_idx} if node.has_any_configurable_weight() else {}
199
+ candidates_sensitivity[bitwidth_idx] = self.sensitivity_evaluator.compute_metric(
200
+ mp_a_cfg={n.name: ind for n, ind in a_cfg.items()},
201
+ mp_w_cfg={n.name: ind for n, ind in w_cfg.items()}
202
+ )
203
+ if verbose:
204
+ print(f'{node.name}\n{"raw metric:":25}', candidates_sensitivity)
205
+ max_ind = node.find_max_candidate_index()
206
+ candidates_sensitivity = normalize(candidates_sensitivity, max_ind)
207
+ candidates_sensitivity = ensure_maxbit_minimal_metric(candidates_sensitivity, max_ind)
208
+ layer_to_metrics_mapping[node] = candidates_sensitivity
201
209
 
202
210
  # Finalize distance metric mapping
203
211
  self._finalize_distance_metric(layer_to_metrics_mapping)
@@ -244,8 +252,9 @@ class MixedPrecisionSearchManager:
244
252
  else:
245
253
  cfg = self.min_ru_config.copy()
246
254
  cfg[node] = candidate_idx
247
- real_cfg = self.config_reconstruction_helper.reconstruct_full_configuration(cfg)
248
- candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, real_cfg)
255
+ if self.using_virtual_graph:
256
+ cfg = self.config_reconstructor.reconstruct_full_configuration(cfg)
257
+ candidate_rus = self.orig_graph_ru_helper.compute_utilization(self.ru_targets, cfg)
249
258
 
250
259
  for target, ru in candidate_rus.items():
251
260
  rus_per_candidate[target].append(ru)
@@ -283,8 +292,8 @@ class MixedPrecisionSearchManager:
283
292
  with the given config.
284
293
 
285
294
  """
286
- act_qcs, w_qcs = self.ru_helper.get_quantization_candidates(config)
287
- ru = self.ru_helper.ru_calculator.compute_resource_utilization(
295
+ act_qcs, w_qcs = self.orig_graph_ru_helper.get_quantization_candidates(config)
296
+ ru = self.orig_graph_ru_helper.ru_calculator.compute_resource_utilization(
288
297
  target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
289
298
  w_qcs=w_qcs, ru_targets=self.ru_targets, allow_unused_qcs=True)
290
299
  return ru
@@ -303,7 +312,7 @@ class MixedPrecisionSearchManager:
303
312
  # normalize metric for numerical stability
304
313
  max_dist = max(itertools.chain.from_iterable(layer_to_metrics_mapping.values()))
305
314
 
306
- if max_dist >= self.sensitivity_evaluator.quant_config.metric_normalization_threshold:
315
+ if max_dist >= self.mp_config.metric_normalization_threshold:
307
316
  Logger.warning(f"The mixed precision distance metric values indicate a large error in the quantized model."
308
317
  f"this can cause numerical issues."
309
318
  f"The program will proceed with mixed precision search after scaling the metric values,"
@@ -387,7 +396,9 @@ class ConfigReconstructionHelper:
387
396
 
388
397
  return orig_cfg
389
398
 
390
- def reconstruct_separate_aw_configs(self, virtual_cfg: Dict[BaseNode, int], include_non_configurable: bool) \
399
+ def reconstruct_separate_aw_configs(self,
400
+ virtual_cfg: Dict[BaseNode, int],
401
+ include_non_configurable: bool = False) \
391
402
  -> Tuple[Dict[BaseNode, int], Dict[BaseNode, int]]:
392
403
  """
393
404
  Retrieves original activation and weights nodes and corresponding candidates for a given configuration of the