mct-nightly 2.2.0.20240910.451__py3-none-any.whl → 2.2.0.20240912.453__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/RECORD +18 -18
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +41 -55
- model_compression_toolkit/core/common/quantization/core_config.py +20 -24
- model_compression_toolkit/core/common/quantization/debug_config.py +12 -17
- model_compression_toolkit/core/common/quantization/quantization_config.py +37 -79
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
- model_compression_toolkit/core/runner.py +4 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- {mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/top_level.txt +0 -0
{mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=MQHqvnJpE47tpv4ydHmi75LJ0XLa-WSNrWwOeRTKlwQ,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -8,7 +8,7 @@ model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-t
|
|
8
8
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
9
9
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
|
10
10
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
|
11
|
-
model_compression_toolkit/core/runner.py,sha256=
|
11
|
+
model_compression_toolkit/core/runner.py,sha256=Wd0cNVMLOPX5cGY5kwz0J64rm87JKd-onJ2k01S9nLo,14362
|
12
12
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
13
13
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
14
14
|
model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
|
@@ -64,7 +64,7 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,s
|
|
64
64
|
model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
|
65
65
|
model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=H8qYkJsk88OszUJo-Zde7vTmWiypLTg9KbbzIZ-hhvM,2812
|
66
66
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=klmaMQDeFc3IxRLf6YX4Dw1opFksbLyN10yFHdKAtLo,4875
|
67
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=
|
67
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
|
68
68
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
|
70
70
|
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=QdxFQ0JxsrcSfk5LlUU_3oZpEK7bYwKelGzEHh0mnJY,27558
|
@@ -102,11 +102,11 @@ model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256
|
|
102
102
|
model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
103
103
|
model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=IXHkpI9bH3AbrpC5T5bNYHcojHzeWQrrCpV-xZj5pks,5021
|
104
104
|
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=yU-Cr6S4wOSkDk57iH2NVe-WII0whOhLryejkomCOt4,4940
|
105
|
-
model_compression_toolkit/core/common/quantization/core_config.py,sha256=
|
106
|
-
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=
|
105
|
+
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
106
|
+
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
|
107
107
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
|
108
108
|
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=YycYN8_JMzvSR3pTVm5dT5x4zP3yBHn0Z9agnwrvOKI,26395
|
109
|
-
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=
|
109
|
+
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BTDa1Izpdd4Z4essxTWP42V87f8mdq9vdKdVhE8vibo,3818
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=eyosbVdnCwed7oMQ19tqnh0VoyGZ_UAuD_UnNoXyBpo,2210
|
111
111
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
|
112
112
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
@@ -142,7 +142,7 @@ model_compression_toolkit/core/common/substitutions/linear_collapsing_substituti
|
|
142
142
|
model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
|
143
143
|
model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
|
144
144
|
model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=p57u25qdW2pimxzGwgMXEBV4S-LzXuTVAlIM7830WfU,10966
|
145
|
-
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=
|
145
|
+
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=9Wq-nZahcmKkZmoo9Pqgb_v_6Rd0z_8HlVjbEbKvl8M,29977
|
146
146
|
model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
|
147
147
|
model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=aXzUOJfgKPfQpEGfiIun26fgfCqazBG1mBpzoc4Ezxs,3477
|
148
148
|
model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=h85L2VlDOqbLd-N98wA3SdYWiblBgSsPceNuLanJd70,4737
|
@@ -354,7 +354,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
|
|
354
354
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
355
355
|
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=NXTNsVrO9DTh0uvc8V7rFaM0fYg2OA18ZrYd-cKZ7Z4,19159
|
356
356
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
|
357
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
|
357
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=iSHnMEdoIqHYqLCTsdK8uxhKbZuuaDOu_BeQ10Z492U,15715
|
358
358
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
359
359
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
|
360
360
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
@@ -371,7 +371,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
|
|
371
371
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
372
372
|
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=-daninmlPGfKsBNPB2C3gT6rK0G5YeyJsuOLA0JlfBU,16633
|
373
373
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
|
374
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
|
374
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=lw9pOV5SKOw9kqOsfskuUiSH_UGOPRczTMpyzN_WTjY,13953
|
375
375
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
376
376
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
|
377
377
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
@@ -391,14 +391,14 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=oStXze__7XCm0
|
|
391
391
|
model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
|
392
392
|
model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
|
393
393
|
model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
394
|
-
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
|
394
|
+
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=7_xoCYzA5TKwJSqMf8GlxlZHOmpAwNdmkfudwJsTIiI,10972
|
395
395
|
model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
396
|
-
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=
|
396
|
+
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=ynOZ30heMp1bSTCYS6vn3EmmycDJn4G72IYZlzkRFPA,9309
|
397
397
|
model_compression_toolkit/qat/__init__.py,sha256=b2mURFGsvaZz_CdAD_w2I4Cdu8ZDN-2iGHMBHTKT5ws,1128
|
398
398
|
model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
|
399
399
|
model_compression_toolkit/qat/common/qat_config.py,sha256=xtfVSoyELGXynHNrw86dB9FU3Inu0zwehc3wLrh7JvY,2918
|
400
400
|
model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
401
|
-
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=
|
401
|
+
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=LNM2HW4cNei3tUhwLdNtsWrox_uSAhaswFxWiMEIrPM,17278
|
402
402
|
model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
|
403
403
|
model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=hoY3AETaLSRP7YfecZ32tyUUj-X_DHRWkV8nALYeRlY,2202
|
404
404
|
model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
|
@@ -410,7 +410,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
|
|
410
410
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=fPAC49mBlB5ViaQT_xHUTC8EvH84OsBX3WAPusqYcM8,13538
|
411
411
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=6YS0v1qCq5dRqtLKHc2gHaKJWfql84TxtZ7pypaZock,10810
|
412
412
|
model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
413
|
-
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=
|
413
|
+
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=NnFy2E_7SR2m8vfh8Q8VrXOXhe7rMScgXnYBtDpsqVs,13456
|
414
414
|
model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
|
415
415
|
model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_weight_quantizer.py,sha256=gjzrnBAZr5c_OrDpSjxpQYa_jKImv7ll52cng07_2oE,1813
|
416
416
|
model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=lM10cGUkkTDtRyLLdWj5Rk0cgvcxp0uaCseyvrnk_Vg,5752
|
@@ -536,8 +536,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
536
536
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
|
537
537
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
538
538
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
539
|
-
mct_nightly-2.2.0.
|
540
|
-
mct_nightly-2.2.0.
|
541
|
-
mct_nightly-2.2.0.
|
542
|
-
mct_nightly-2.2.0.
|
543
|
-
mct_nightly-2.2.0.
|
539
|
+
mct_nightly-2.2.0.20240912.453.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
540
|
+
mct_nightly-2.2.0.20240912.453.dist-info/METADATA,sha256=s63h4FLAEjfw0suj6RB6Td_lWn9sGzpmOA4CYE2GkiA,20813
|
541
|
+
mct_nightly-2.2.0.20240912.453.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
542
|
+
mct_nightly-2.2.0.20240912.453.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
543
|
+
mct_nightly-2.2.0.20240912.453.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.2.0.
|
30
|
+
__version__ = "2.2.0.20240912.000453"
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py
CHANGED
@@ -13,75 +13,61 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from
|
17
|
-
|
16
|
+
from dataclasses import dataclass, field
|
17
|
+
from typing import List, Callable, Optional
|
18
18
|
from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
19
19
|
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
|
20
20
|
|
21
21
|
|
22
|
+
@dataclass
|
22
23
|
class MixedPrecisionQuantizationConfig:
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
self.
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
assert 0.0 < num_interest_points_factor <= 1.0, "num_interest_points_factor should represent a percentage of " \
|
59
|
-
"the base set of interest points that are required to be " \
|
60
|
-
"used for mixed-precision metric evaluation, " \
|
61
|
-
"thus, it should be between 0 to 1"
|
62
|
-
self.num_interest_points_factor = num_interest_points_factor
|
63
|
-
|
64
|
-
self.use_hessian_based_scores = use_hessian_based_scores
|
65
|
-
self.norm_scores = norm_scores
|
66
|
-
self.hessian_batch_size = hessian_batch_size
|
67
|
-
|
68
|
-
self.metric_normalization_threshold = metric_normalization_threshold
|
69
|
-
|
70
|
-
self._mixed_precision_enable = False
|
24
|
+
"""
|
25
|
+
Class with mixed precision parameters to quantize the input model.
|
26
|
+
|
27
|
+
Args:
|
28
|
+
compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
|
29
|
+
distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
|
30
|
+
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
|
31
|
+
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
|
32
|
+
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
|
33
|
+
use_hessian_based_scores (bool): Whether to use Hessian-based scores for weighted average distance metric computation.
|
34
|
+
norm_scores (bool): Whether to normalize the returned scores for the weighted distance metric (to get values between 0 and 1).
|
35
|
+
refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
|
36
|
+
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
|
37
|
+
hessian_batch_size (int): The Hessian computation batch size. used only if using mixed precision with Hessian-based objective.
|
38
|
+
"""
|
39
|
+
|
40
|
+
compute_distance_fn: Optional[Callable] = None
|
41
|
+
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
|
42
|
+
num_of_images: int = MP_DEFAULT_NUM_SAMPLES
|
43
|
+
configuration_overwrite: Optional[List[int]] = None
|
44
|
+
num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})
|
45
|
+
use_hessian_based_scores: bool = False
|
46
|
+
norm_scores: bool = True
|
47
|
+
refine_mp_solution: bool = True
|
48
|
+
metric_normalization_threshold: float = 1e10
|
49
|
+
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
50
|
+
_is_mixed_precision_enabled: bool = field(init=False, default=False)
|
51
|
+
|
52
|
+
def __post_init__(self):
|
53
|
+
# Validate num_interest_points_factor
|
54
|
+
assert 0.0 < self.num_interest_points_factor <= 1.0, \
|
55
|
+
"num_interest_points_factor should represent a percentage of " \
|
56
|
+
"the base set of interest points that are required to be " \
|
57
|
+
"used for mixed-precision metric evaluation, " \
|
58
|
+
"thus, it should be between 0 to 1"
|
71
59
|
|
72
60
|
def set_mixed_precision_enable(self):
|
73
61
|
"""
|
74
62
|
Set a flag in mixed precision config indicating that mixed precision is enabled.
|
75
63
|
"""
|
76
|
-
|
77
|
-
self._mixed_precision_enable = True
|
64
|
+
self._is_mixed_precision_enabled = True
|
78
65
|
|
79
66
|
@property
|
80
|
-
def
|
67
|
+
def is_mixed_precision_enabled(self):
|
81
68
|
"""
|
82
69
|
A property that indicates whether mixed precision quantization is enabled.
|
83
70
|
|
84
71
|
Returns: True if mixed precision quantization is enabled
|
85
|
-
|
86
72
|
"""
|
87
|
-
return self.
|
73
|
+
return self._is_mixed_precision_enabled
|
@@ -12,41 +12,37 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from dataclasses import dataclass, field
|
16
|
+
from typing import Optional
|
17
|
+
|
15
18
|
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
|
16
19
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
17
20
|
from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
|
18
21
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
|
19
22
|
|
20
23
|
|
24
|
+
@dataclass
|
21
25
|
class CoreConfig:
|
22
26
|
"""
|
23
|
-
A
|
24
|
-
"""
|
25
|
-
def __init__(self,
|
26
|
-
quantization_config: QuantizationConfig = None,
|
27
|
-
mixed_precision_config: MixedPrecisionQuantizationConfig = None,
|
28
|
-
bit_width_config: BitWidthConfig = None,
|
29
|
-
debug_config: DebugConfig = None
|
30
|
-
):
|
31
|
-
"""
|
27
|
+
A dataclass to hold the configurations classes of the MCT-core.
|
32
28
|
|
33
|
-
|
34
|
-
|
35
|
-
|
29
|
+
Args:
|
30
|
+
quantization_config (QuantizationConfig): Config for quantization.
|
31
|
+
mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization.
|
36
32
|
If None, a default MixedPrecisionQuantizationConfig is used.
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
|
41
|
-
self.bit_width_config = BitWidthConfig() if bit_width_config is None else bit_width_config
|
42
|
-
self.debug_config = DebugConfig() if debug_config is None else debug_config
|
33
|
+
bit_width_config (BitWidthConfig): Config for manual bit-width selection.
|
34
|
+
debug_config (DebugConfig): Config for debugging and editing the network quantization process.
|
35
|
+
"""
|
43
36
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
37
|
+
quantization_config: QuantizationConfig = field(default_factory=QuantizationConfig)
|
38
|
+
mixed_precision_config: MixedPrecisionQuantizationConfig = field(default_factory=MixedPrecisionQuantizationConfig)
|
39
|
+
bit_width_config: BitWidthConfig = field(default_factory=BitWidthConfig)
|
40
|
+
debug_config: DebugConfig = field(default_factory=DebugConfig)
|
48
41
|
|
49
42
|
@property
|
50
|
-
def
|
51
|
-
|
43
|
+
def is_mixed_precision_enabled(self) -> bool:
|
44
|
+
"""
|
45
|
+
A property that indicates whether mixed precision is enabled.
|
46
|
+
"""
|
47
|
+
return bool(self.mixed_precision_config and self.mixed_precision_config.is_mixed_precision_enabled)
|
52
48
|
|
@@ -13,29 +13,24 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
16
|
+
from dataclasses import dataclass, field
|
17
17
|
from typing import List
|
18
18
|
|
19
19
|
from model_compression_toolkit.core.common.network_editors.edit_network import EditRule
|
20
20
|
|
21
21
|
|
22
|
+
@dataclass
|
22
23
|
class DebugConfig:
|
23
24
|
"""
|
24
|
-
A
|
25
|
-
"""
|
26
|
-
def __init__(self,
|
27
|
-
analyze_similarity: bool = False,
|
28
|
-
network_editor: List[EditRule] = [],
|
29
|
-
simulate_scheduler: bool = False):
|
30
|
-
"""
|
25
|
+
A dataclass for MCT core debug information.
|
31
26
|
|
32
|
-
|
27
|
+
Args:
|
28
|
+
analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is
|
29
|
+
enabled) or not. Can be used to pinpoint problematic layers in the quantization process.
|
30
|
+
network_editor (List[EditRule]): A list of rules and actions to edit the network for quantization.
|
31
|
+
simulate_scheduler (bool): Simulate scheduler behavior to compute operators' order and cuts.
|
32
|
+
"""
|
33
33
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
simulate_scheduler (bool): Simulate scheduler behaviour to compute operators order and cuts.
|
38
|
-
"""
|
39
|
-
self.analyze_similarity = analyze_similarity
|
40
|
-
self.network_editor = network_editor
|
41
|
-
self.simulate_scheduler = simulate_scheduler
|
34
|
+
analyze_similarity: bool = False
|
35
|
+
network_editor: List[EditRule] = field(default_factory=list)
|
36
|
+
simulate_scheduler: bool = False
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
16
|
+
from dataclasses import dataclass, field
|
17
17
|
import math
|
18
18
|
from enum import Enum
|
19
19
|
|
@@ -46,86 +46,44 @@ class QuantizationErrorMethod(Enum):
|
|
46
46
|
HMSE = 6
|
47
47
|
|
48
48
|
|
49
|
+
@dataclass
|
49
50
|
class QuantizationConfig:
|
51
|
+
"""
|
52
|
+
A class that encapsulates all the different parameters used by the library to quantize a model.
|
53
|
+
|
54
|
+
Examples:
|
55
|
+
You can create a quantization configuration to apply to a model. For example, to quantize a model's weights and
|
56
|
+
activations using thresholds, with weight threshold selection based on MSE and activation threshold selection
|
57
|
+
using NOCLIPPING (min/max), while enabling relu_bound_to_power_of_2 and weights_bias_correction,
|
58
|
+
you can instantiate a quantization configuration like this:
|
59
|
+
|
60
|
+
>>> import model_compression_toolkit as mct
|
61
|
+
>>> qc = mct.core.QuantizationConfig(activation_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING, weights_error_method=mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=True, weights_bias_correction=True)
|
62
|
+
|
63
|
+
|
64
|
+
The QuantizationConfig instance can then be used in the quantization workflow,
|
65
|
+
such as with Keras in the function: :func:~model_compression_toolkit.ptq.keras_post_training_quantization`.
|
66
|
+
|
67
|
+
"""
|
50
68
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
concat_threshold_update: bool = False):
|
70
|
-
"""
|
71
|
-
Class to wrap all different parameters the library quantize the input model according to.
|
72
|
-
|
73
|
-
Args:
|
74
|
-
activation_error_method (QuantizationErrorMethod): Which method to use from QuantizationErrorMethod for activation quantization threshold selection.
|
75
|
-
weights_error_method (QuantizationErrorMethod): Which method to use from QuantizationErrorMethod for activation quantization threshold selection.
|
76
|
-
relu_bound_to_power_of_2 (bool): Whether to use relu to power of 2 scaling correction or not.
|
77
|
-
weights_bias_correction (bool): Whether to use weights bias correction or not.
|
78
|
-
weights_second_moment_correction (bool): Whether to use weights second_moment correction or not.
|
79
|
-
input_scaling (bool): Whether to use input scaling or not.
|
80
|
-
softmax_shift (bool): Whether to use softmax shift or not.
|
81
|
-
shift_negative_activation_correction (bool): Whether to use shifting negative activation correction or not.
|
82
|
-
activation_channel_equalization (bool): Whether to use activation channel equalization correction or not.
|
83
|
-
z_threshold (float): Value of z score for outliers removal.
|
84
|
-
min_threshold (float): Minimum threshold to use during thresholds selection.
|
85
|
-
l_p_value (int): The p value of L_p norm threshold selection.
|
86
|
-
block_collapsing (bool): Whether to collapse block one to another in the input network
|
87
|
-
shift_negative_ratio (float): Value for the ratio between the minimal negative value of a non-linearity output to its activation threshold, which above it - shifting negative activation should occur if enabled.
|
88
|
-
shift_negative_threshold_recalculation (bool): Whether or not to recompute the threshold after shifting negative activation.
|
89
|
-
shift_negative_params_search (bool): Whether to search for optimal shift and threshold in shift negative activation.
|
90
|
-
|
91
|
-
Examples:
|
92
|
-
One may create a quantization configuration to quantize a model according to.
|
93
|
-
For example, to quantize a model's weights and activation using thresholds, such that
|
94
|
-
weights threshold selection is done using MSE, activation threshold selection is done using NOCLIPPING (min/max),
|
95
|
-
enabling relu_bound_to_power_of_2, weights_bias_correction,
|
96
|
-
one can instantiate a quantization configuration:
|
97
|
-
|
98
|
-
>>> import model_compression_toolkit as mct
|
99
|
-
>>> qc = mct.core.QuantizationConfig(activation_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING, weights_error_method=mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=True, weights_bias_correction=True)
|
100
|
-
|
101
|
-
|
102
|
-
The QuantizationConfig instanse can then be passed to
|
103
|
-
:func:`~model_compression_toolkit.ptq.keras_post_training_quantization`
|
104
|
-
|
105
|
-
"""
|
106
|
-
|
107
|
-
self.activation_error_method = activation_error_method
|
108
|
-
self.weights_error_method = weights_error_method
|
109
|
-
self.relu_bound_to_power_of_2 = relu_bound_to_power_of_2
|
110
|
-
self.weights_bias_correction = weights_bias_correction
|
111
|
-
self.weights_second_moment_correction = weights_second_moment_correction
|
112
|
-
self.activation_channel_equalization = activation_channel_equalization
|
113
|
-
self.input_scaling = input_scaling
|
114
|
-
self.softmax_shift = softmax_shift
|
115
|
-
self.min_threshold = min_threshold
|
116
|
-
self.shift_negative_activation_correction = shift_negative_activation_correction
|
117
|
-
self.z_threshold = z_threshold
|
118
|
-
self.l_p_value = l_p_value
|
119
|
-
self.linear_collapsing = linear_collapsing
|
120
|
-
self.residual_collapsing = residual_collapsing
|
121
|
-
self.shift_negative_ratio = shift_negative_ratio
|
122
|
-
self.shift_negative_threshold_recalculation = shift_negative_threshold_recalculation
|
123
|
-
self.shift_negative_params_search = shift_negative_params_search
|
124
|
-
self.concat_threshold_update = concat_threshold_update
|
125
|
-
|
126
|
-
def __repr__(self):
|
127
|
-
# Used for debugging, thus no cover.
|
128
|
-
return str(self.__dict__) # pragma: no cover
|
69
|
+
activation_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE
|
70
|
+
weights_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE
|
71
|
+
relu_bound_to_power_of_2: bool = False
|
72
|
+
weights_bias_correction: bool = True
|
73
|
+
weights_second_moment_correction: bool = False
|
74
|
+
input_scaling: bool = False
|
75
|
+
softmax_shift: bool = False
|
76
|
+
shift_negative_activation_correction: bool = True
|
77
|
+
activation_channel_equalization: bool = False
|
78
|
+
z_threshold: float = math.inf
|
79
|
+
min_threshold: float = MIN_THRESHOLD
|
80
|
+
l_p_value: int = 2
|
81
|
+
linear_collapsing: bool = True
|
82
|
+
residual_collapsing: bool = True
|
83
|
+
shift_negative_ratio: float = 0.05
|
84
|
+
shift_negative_threshold_recalculation: bool = False
|
85
|
+
shift_negative_params_search: bool = False
|
86
|
+
concat_threshold_update: bool = False
|
129
87
|
|
130
88
|
|
131
89
|
# Default quantization configuration the library use.
|
@@ -360,7 +360,7 @@ def shift_negative_function(graph: Graph,
|
|
360
360
|
graph=graph,
|
361
361
|
quant_config=core_config.quantization_config,
|
362
362
|
tpc=graph.tpc,
|
363
|
-
mixed_precision_enable=core_config.
|
363
|
+
mixed_precision_enable=core_config.is_mixed_precision_enabled)
|
364
364
|
|
365
365
|
for candidate_qc in pad_node.candidates_quantization_cfg:
|
366
366
|
candidate_qc.activation_quantization_cfg.enable_activation_quantization = False
|
@@ -377,7 +377,7 @@ def shift_negative_function(graph: Graph,
|
|
377
377
|
graph=graph,
|
378
378
|
quant_config=core_config.quantization_config,
|
379
379
|
tpc=graph.tpc,
|
380
|
-
mixed_precision_enable=core_config.
|
380
|
+
mixed_precision_enable=core_config.is_mixed_precision_enabled)
|
381
381
|
|
382
382
|
original_non_linear_activation_nbits = non_linear_node_cfg_candidate.activation_n_bits
|
383
383
|
# The non-linear node's output should be float, so we approximate it by using 16bits quantization.
|
@@ -119,7 +119,7 @@ def core_runner(in_model: Any,
|
|
119
119
|
tpc,
|
120
120
|
core_config.bit_width_config,
|
121
121
|
tb_w,
|
122
|
-
mixed_precision_enable=core_config.
|
122
|
+
mixed_precision_enable=core_config.is_mixed_precision_enabled,
|
123
123
|
running_gptq=running_gptq)
|
124
124
|
|
125
125
|
hessian_info_service = HessianInfoService(graph=graph, representative_dataset_gen=representative_data_gen,
|
@@ -136,7 +136,7 @@ def core_runner(in_model: Any,
|
|
136
136
|
######################################
|
137
137
|
# Finalize bit widths
|
138
138
|
######################################
|
139
|
-
if core_config.
|
139
|
+
if core_config.is_mixed_precision_enabled:
|
140
140
|
if core_config.mixed_precision_config.configuration_overwrite is None:
|
141
141
|
|
142
142
|
filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, tpc)
|
@@ -161,7 +161,7 @@ def core_runner(in_model: Any,
|
|
161
161
|
else:
|
162
162
|
bit_widths_config = []
|
163
163
|
|
164
|
-
tg = set_bit_widths(core_config.
|
164
|
+
tg = set_bit_widths(core_config.is_mixed_precision_enabled,
|
165
165
|
tg,
|
166
166
|
bit_widths_config)
|
167
167
|
|
@@ -175,7 +175,7 @@ def core_runner(in_model: Any,
|
|
175
175
|
fw_info=fw_info,
|
176
176
|
fw_impl=fw_impl)
|
177
177
|
|
178
|
-
if core_config.
|
178
|
+
if core_config.is_mixed_precision_enabled:
|
179
179
|
# Retrieve lists of tuples (node, node's final weights/activation bitwidth)
|
180
180
|
weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
|
181
181
|
activation_conf_nodes_bitwidth = tg.get_final_activation_config()
|
@@ -199,7 +199,7 @@ if FOUND_TF:
|
|
199
199
|
KerasModelValidation(model=in_model,
|
200
200
|
fw_info=DEFAULT_KERAS_INFO).validate()
|
201
201
|
|
202
|
-
if core_config.
|
202
|
+
if core_config.is_mixed_precision_enabled:
|
203
203
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
204
204
|
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
|
205
205
|
"Ensure usage of the correct API for keras_post_training_quantization "
|
@@ -165,7 +165,7 @@ if FOUND_TORCH:
|
|
165
165
|
|
166
166
|
"""
|
167
167
|
|
168
|
-
if core_config.
|
168
|
+
if core_config.is_mixed_precision_enabled:
|
169
169
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
170
170
|
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
|
171
171
|
"Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
|
@@ -124,7 +124,7 @@ if FOUND_TF:
|
|
124
124
|
KerasModelValidation(model=in_model,
|
125
125
|
fw_info=fw_info).validate()
|
126
126
|
|
127
|
-
if core_config.
|
127
|
+
if core_config.is_mixed_precision_enabled:
|
128
128
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
129
129
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
130
130
|
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
|
@@ -96,7 +96,7 @@ if FOUND_TORCH:
|
|
96
96
|
|
97
97
|
fw_info = DEFAULT_PYTORCH_INFO
|
98
98
|
|
99
|
-
if core_config.
|
99
|
+
if core_config.is_mixed_precision_enabled:
|
100
100
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
101
101
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
102
102
|
"MixedPrecisionQuantizationConfig. Please use "
|
@@ -176,7 +176,7 @@ if FOUND_TF:
|
|
176
176
|
KerasModelValidation(model=in_model,
|
177
177
|
fw_info=DEFAULT_KERAS_INFO).validate()
|
178
178
|
|
179
|
-
if core_config.
|
179
|
+
if core_config.is_mixed_precision_enabled:
|
180
180
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
181
181
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
182
182
|
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
|
@@ -145,7 +145,7 @@ if FOUND_TORCH:
|
|
145
145
|
f"If you encounter an issue, please open an issue in our GitHub "
|
146
146
|
f"project https://github.com/sony/model_optimization")
|
147
147
|
|
148
|
-
if core_config.
|
148
|
+
if core_config.is_mixed_precision_enabled:
|
149
149
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
150
150
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
151
151
|
"MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
|
{mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.2.0.20240910.451.dist-info → mct_nightly-2.2.0.20240912.453.dist-info}/top_level.txt
RENAMED
File without changes
|