mct-nightly 2.2.0.20240911.455__py3-none-any.whl → 2.2.0.20240912.453__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mct-nightly
-Version: 2.2.0.20240911.455
+Version: 2.2.0.20240912.453
 Summary: A Model Compression Toolkit for neural networks
 Home-page: UNKNOWN
 License: UNKNOWN
@@ -1,4 +1,4 @@
-model_compression_toolkit/__init__.py,sha256=td8mVc5zVMOS9tP4ad7FK5FNTuAiS3FlWtKKtdaQKik,1573
+model_compression_toolkit/__init__.py,sha256=MQHqvnJpE47tpv4ydHmi75LJ0XLa-WSNrWwOeRTKlwQ,1573
 model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -8,7 +8,7 @@ model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-t
 model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
 model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
 model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
-model_compression_toolkit/core/runner.py,sha256=ryHhW5Qqu7XHVkngLF0uLX8oa4CxNAIF4PoGBeUjoSk,14346
+model_compression_toolkit/core/runner.py,sha256=Wd0cNVMLOPX5cGY5kwz0J64rm87JKd-onJ2k01S9nLo,14362
 model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
 model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
 model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
@@ -64,7 +64,7 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,s
 model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
 model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=H8qYkJsk88OszUJo-Zde7vTmWiypLTg9KbbzIZ-hhvM,2812
 model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=klmaMQDeFc3IxRLf6YX4Dw1opFksbLyN10yFHdKAtLo,4875
-model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=rppRZJdSCQGiZsd93QxoUIhj51eETvQbuI5JiC2TUeA,4963
+model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
 model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
 model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
 model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=QdxFQ0JxsrcSfk5LlUU_3oZpEK7bYwKelGzEHh0mnJY,27558
@@ -102,11 +102,11 @@ model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256
 model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=IXHkpI9bH3AbrpC5T5bNYHcojHzeWQrrCpV-xZj5pks,5021
 model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=yU-Cr6S4wOSkDk57iH2NVe-WII0whOhLryejkomCOt4,4940
-model_compression_toolkit/core/common/quantization/core_config.py,sha256=f0uSuY9mX-vLX_1s2DemPARQlAXmLPKJKPtCArz3pZI,2670
-model_compression_toolkit/core/common/quantization/debug_config.py,sha256=8G8SpE_4rb8xBp8d6mMq8R_OnXJ_1oxB2g-Lxk9EJCM,1691
+model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
+model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
 model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
 model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=YycYN8_JMzvSR3pTVm5dT5x4zP3yBHn0Z9agnwrvOKI,26395
-model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=7dDs9pq9dM9ADVeIi7wyMpW9ZbAI9GLujgxt7nxvnng,7105
+model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BTDa1Izpdd4Z4essxTWP42V87f8mdq9vdKdVhE8vibo,3818
 model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=eyosbVdnCwed7oMQ19tqnh0VoyGZ_UAuD_UnNoXyBpo,2210
 model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
 model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
@@ -142,7 +142,7 @@ model_compression_toolkit/core/common/substitutions/linear_collapsing_substituti
 model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
 model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
 model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=p57u25qdW2pimxzGwgMXEBV4S-LzXuTVAlIM7830WfU,10966
-model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=AqQ0cTMz0d1qziQD5uUeYJON0wfXKvRIADuonF8Hobs,29969
+model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=9Wq-nZahcmKkZmoo9Pqgb_v_6Rd0z_8HlVjbEbKvl8M,29977
 model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
 model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=aXzUOJfgKPfQpEGfiIun26fgfCqazBG1mBpzoc4Ezxs,3477
 model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=h85L2VlDOqbLd-N98wA3SdYWiblBgSsPceNuLanJd70,4737
@@ -354,7 +354,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
 model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
 model_compression_toolkit/gptq/keras/gptq_training.py,sha256=NXTNsVrO9DTh0uvc8V7rFaM0fYg2OA18ZrYd-cKZ7Z4,19159
 model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
-model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=t4Jxtu8qyGbIftI5l2sb79Ydd85XM6GyDpkCqiotVF8,15711
+model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=iSHnMEdoIqHYqLCTsdK8uxhKbZuuaDOu_BeQ10Z492U,15715
 model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
 model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
 model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -371,7 +371,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
 model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
 model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=-daninmlPGfKsBNPB2C3gT6rK0G5YeyJsuOLA0JlfBU,16633
 model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
-model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=TMus5LYJnTngLKot7coVax8gsIzPDYVU9m6orFPvWSY,13949
+model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=lw9pOV5SKOw9kqOsfskuUiSH_UGOPRczTMpyzN_WTjY,13953
 model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
 model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
 model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -391,14 +391,14 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=oStXze__7XCm0
 model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
 model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
 model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=DAAJPd6pKLgiwoJT-_u2dvVOO4Ox6IgJgfiUbnNRBwQ,10968
+model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=7_xoCYzA5TKwJSqMf8GlxlZHOmpAwNdmkfudwJsTIiI,10972
 model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=xHVTrm9Fyk_j4j8G1Pb97qacN_gn9cGYpsT1HXdTc1A,9305
+model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=ynOZ30heMp1bSTCYS6vn3EmmycDJn4G72IYZlzkRFPA,9309
 model_compression_toolkit/qat/__init__.py,sha256=b2mURFGsvaZz_CdAD_w2I4Cdu8ZDN-2iGHMBHTKT5ws,1128
 model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
 model_compression_toolkit/qat/common/qat_config.py,sha256=xtfVSoyELGXynHNrw86dB9FU3Inu0zwehc3wLrh7JvY,2918
 model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/qat/keras/quantization_facade.py,sha256=VaZTqK53TOWrXebnJzoHHD99DxOgS4NzHGbmYWaajWA,17274
+model_compression_toolkit/qat/keras/quantization_facade.py,sha256=LNM2HW4cNei3tUhwLdNtsWrox_uSAhaswFxWiMEIrPM,17278
 model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
 model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=hoY3AETaLSRP7YfecZ32tyUUj-X_DHRWkV8nALYeRlY,2202
 model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -410,7 +410,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
 model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=fPAC49mBlB5ViaQT_xHUTC8EvH84OsBX3WAPusqYcM8,13538
 model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=6YS0v1qCq5dRqtLKHc2gHaKJWfql84TxtZ7pypaZock,10810
 model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=1eg0jMgFzRLYIFnG9GJnJ8U3W4IOM-4Z27s9Wq-JeOQ,13452
+model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=NnFy2E_7SR2m8vfh8Q8VrXOXhe7rMScgXnYBtDpsqVs,13456
 model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
 model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_weight_quantizer.py,sha256=gjzrnBAZr5c_OrDpSjxpQYa_jKImv7ll52cng07_2oE,1813
 model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=lM10cGUkkTDtRyLLdWj5Rk0cgvcxp0uaCseyvrnk_Vg,5752
@@ -536,8 +536,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
 model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
 model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
 model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
-mct_nightly-2.2.0.20240911.455.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
-mct_nightly-2.2.0.20240911.455.dist-info/METADATA,sha256=okdby2_mBNXyKyLCkF7xT2Ma8QPEhiP1UzWEJPAFbPY,20813
-mct_nightly-2.2.0.20240911.455.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-mct_nightly-2.2.0.20240911.455.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
-mct_nightly-2.2.0.20240911.455.dist-info/RECORD,,
+mct_nightly-2.2.0.20240912.453.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-2.2.0.20240912.453.dist-info/METADATA,sha256=s63h4FLAEjfw0suj6RB6Td_lWn9sGzpmOA4CYE2GkiA,20813
+mct_nightly-2.2.0.20240912.453.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+mct_nightly-2.2.0.20240912.453.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.2.0.20240912.453.dist-info/RECORD,,
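The RECORD hunks above are bookkeeping only: RECORD lists every installed file with its sha256 and size, so each changed module reappears here with a new hash. As a minimal sketch (not part of the package), the hash format per PEP 427 is the urlsafe-base64 sha256 digest with trailing padding stripped:

import base64
import hashlib
from pathlib import Path

def record_hash(path: str) -> str:
    # Produces the "sha256=..." field of a wheel RECORD entry:
    # "path,sha256=<urlsafe-b64 digest, no padding>,<size>"
    digest = hashlib.sha256(Path(path).read_bytes()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()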
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.2.0.20240911.000455"
+__version__ = "2.2.0.20240912.000453"
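Note that __version__ keeps a leading-zero segment ("000453") while the wheel filename says "453"; the two agree after PEP 440 normalization, which strips leading zeros from numeric release segments. A quick check with the third-party packaging library:

from packaging.version import Version  # pip install packaging

# PEP 440 normalization makes "000453" and "453" the same release segment.
assert str(Version("2.2.0.20240912.000453")) == "2.2.0.20240912.453"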
@@ -13,75 +13,61 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import List, Callable
-
+from dataclasses import dataclass, field
+from typing import List, Callable, Optional
 from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
 from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 
 
+@dataclass
 class MixedPrecisionQuantizationConfig:
-
-    def __init__(self,
-                 compute_distance_fn: Callable = None,
-                 distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG,
-                 num_of_images: int = MP_DEFAULT_NUM_SAMPLES,
-                 configuration_overwrite: List[int] = None,
-                 num_interest_points_factor: float = 1.0,
-                 use_hessian_based_scores: bool = False,
-                 norm_scores: bool = True,
-                 refine_mp_solution: bool = True,
-                 metric_normalization_threshold: float = 1e10,
-                 hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE):
-        """
-        Class with mixed precision parameters to quantize the input model.
-
-        Args:
-            compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
-            distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
-            num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
-            configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
-            num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
-            use_hessian_based_scores (bool): Whether to use Hessian-based scores for weighted average distance metric computation.
-            norm_scores (bool): Whether to normalize the returned scores for the weighted distance metric (to get values between 0 and 1).
-            refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
-            metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
-            hessian_batch_size (int): The Hessian computation batch size. used only if using mixed precision with Hessian-based objective.
-
-        """
-
-        self.compute_distance_fn = compute_distance_fn
-        self.distance_weighting_method = distance_weighting_method
-        self.num_of_images = num_of_images
-        self.configuration_overwrite = configuration_overwrite
-        self.refine_mp_solution = refine_mp_solution
-
-        assert 0.0 < num_interest_points_factor <= 1.0, "num_interest_points_factor should represent a percentage of " \
-                                                        "the base set of interest points that are required to be " \
-                                                        "used for mixed-precision metric evaluation, " \
-                                                        "thus, it should be between 0 to 1"
-        self.num_interest_points_factor = num_interest_points_factor
-
-        self.use_hessian_based_scores = use_hessian_based_scores
-        self.norm_scores = norm_scores
-        self.hessian_batch_size = hessian_batch_size
-
-        self.metric_normalization_threshold = metric_normalization_threshold
-
-        self._mixed_precision_enable = False
+    """
+    Class with mixed precision parameters to quantize the input model.
+
+    Args:
+        compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
+        distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
+        num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
+        configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
+        num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
+        use_hessian_based_scores (bool): Whether to use Hessian-based scores for weighted average distance metric computation.
+        norm_scores (bool): Whether to normalize the returned scores for the weighted distance metric (to get values between 0 and 1).
+        refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
+        metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
+        hessian_batch_size (int): The Hessian computation batch size. used only if using mixed precision with Hessian-based objective.
+    """
+
+    compute_distance_fn: Optional[Callable] = None
+    distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
+    num_of_images: int = MP_DEFAULT_NUM_SAMPLES
+    configuration_overwrite: Optional[List[int]] = None
+    num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})
+    use_hessian_based_scores: bool = False
+    norm_scores: bool = True
+    refine_mp_solution: bool = True
+    metric_normalization_threshold: float = 1e10
+    hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
+    _is_mixed_precision_enabled: bool = field(init=False, default=False)
+
+    def __post_init__(self):
+        # Validate num_interest_points_factor
+        assert 0.0 < self.num_interest_points_factor <= 1.0, \
+            "num_interest_points_factor should represent a percentage of " \
+            "the base set of interest points that are required to be " \
+            "used for mixed-precision metric evaluation, " \
+            "thus, it should be between 0 to 1"
 
     def set_mixed_precision_enable(self):
         """
         Set a flag in mixed precision config indicating that mixed precision is enabled.
         """
-
-        self._mixed_precision_enable = True
+        self._is_mixed_precision_enabled = True
 
     @property
-    def mixed_precision_enable(self):
+    def is_mixed_precision_enabled(self):
        """
        A property that indicates whether mixed precision quantization is enabled.
 
        Returns: True if mixed precision quantization is enabled
-
        """
-        return self._mixed_precision_enable
+        return self._is_mixed_precision_enabled
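This hunk rewrites MixedPrecisionQuantizationConfig as a @dataclass: constructor parameters become declared fields, the range check moves into __post_init__, and the renamed private flag is declared with field(init=False) so it stays out of the generated __init__. A minimal sketch of the pattern, using a hypothetical stand-in class rather than the MCT type itself:

from dataclasses import dataclass, field

@dataclass
class MpConfigSketch:
    num_of_images: int = 32
    num_interest_points_factor: float = 1.0
    # init=False keeps the internal flag out of the generated __init__ signature.
    _is_mixed_precision_enabled: bool = field(init=False, default=False)

    def __post_init__(self):
        # Runs right after the generated __init__, so invalid values are
        # still rejected at construction time, as with the old class.
        assert 0.0 < self.num_interest_points_factor <= 1.0

cfg = MpConfigSketch(num_interest_points_factor=0.5)   # ok
# MpConfigSketch(num_interest_points_factor=1.5)       # would raise AssertionError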
@@ -12,41 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from dataclasses import dataclass, field
+from typing import Optional
+
 from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
 
 
+@dataclass
 class CoreConfig:
     """
-    A class to hold the configurations classes of the MCT-core.
-    """
-    def __init__(self,
-                 quantization_config: QuantizationConfig = None,
-                 mixed_precision_config: MixedPrecisionQuantizationConfig = None,
-                 bit_width_config: BitWidthConfig = None,
-                 debug_config: DebugConfig = None
-                 ):
-        """
+    A dataclass to hold the configurations classes of the MCT-core.
 
-        Args:
-            quantization_config (QuantizationConfig): Config for quantization.
-            mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization.
+    Args:
+        quantization_config (QuantizationConfig): Config for quantization.
+        mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization.
            If None, a default MixedPrecisionQuantizationConfig is used.
-            bit_width_config (BitWidthConfig): Config for manual bit-width selection.
-            debug_config (DebugConfig): Config for debugging and editing the network quantization process.
-        """
-        self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
-        self.bit_width_config = BitWidthConfig() if bit_width_config is None else bit_width_config
-        self.debug_config = DebugConfig() if debug_config is None else debug_config
+        bit_width_config (BitWidthConfig): Config for manual bit-width selection.
+        debug_config (DebugConfig): Config for debugging and editing the network quantization process.
+    """
 
-        if mixed_precision_config is None:
-            self.mixed_precision_config = MixedPrecisionQuantizationConfig()
-        else:
-            self.mixed_precision_config = mixed_precision_config
+    quantization_config: QuantizationConfig = field(default_factory=QuantizationConfig)
+    mixed_precision_config: MixedPrecisionQuantizationConfig = field(default_factory=MixedPrecisionQuantizationConfig)
+    bit_width_config: BitWidthConfig = field(default_factory=BitWidthConfig)
+    debug_config: DebugConfig = field(default_factory=DebugConfig)
 
     @property
-    def mixed_precision_enable(self):
-        return self.mixed_precision_config is not None and self.mixed_precision_config.mixed_precision_enable
+    def is_mixed_precision_enabled(self) -> bool:
+        """
+        A property that indicates whether mixed precision is enabled.
+        """
+        return bool(self.mixed_precision_config and self.mixed_precision_config.is_mixed_precision_enabled)
 
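CoreConfig gets the same treatment, with field(default_factory=...) replacing the "if x is None: x = Default()" boilerplate from the old __init__. The factory runs once per instance, so default sub-configs are never shared between two CoreConfig objects. A small sketch with hypothetical stand-in classes:

from dataclasses import dataclass, field

@dataclass
class SubConfig:
    enabled: bool = False

@dataclass
class CoreConfigSketch:
    # default_factory constructs a fresh SubConfig for each instance.
    sub: SubConfig = field(default_factory=SubConfig)

a, b = CoreConfigSketch(), CoreConfigSketch()
a.sub.enabled = True
assert b.sub.enabled is False  # defaults are per-instance, not shared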
@@ -13,29 +13,24 @@
 # limitations under the License.
 # ==============================================================================
 
-
+from dataclasses import dataclass, field
 from typing import List
 
 from model_compression_toolkit.core.common.network_editors.edit_network import EditRule
 
 
+@dataclass
 class DebugConfig:
     """
-    A class for MCT core debug information.
-    """
-    def __init__(self,
-                 analyze_similarity: bool = False,
-                 network_editor: List[EditRule] = [],
-                 simulate_scheduler: bool = False):
-        """
+    A dataclass for MCT core debug information.
 
-        Args:
+    Args:
+        analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is
+            enabled) or not. Can be used to pinpoint problematic layers in the quantization process.
+        network_editor (List[EditRule]): A list of rules and actions to edit the network for quantization.
+        simulate_scheduler (bool): Simulate scheduler behavior to compute operators' order and cuts.
+    """
 
-        analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is
-            enabled) or not. Can be used to pinpoint problematic layers in the quantization process.
-        network_editor (List[EditRule]): A list of rules and actions to edit the network for quantization.
-        simulate_scheduler (bool): Simulate scheduler behaviour to compute operators order and cuts.
-        """
-        self.analyze_similarity = analyze_similarity
-        self.network_editor = network_editor
-        self.simulate_scheduler = simulate_scheduler
+    analyze_similarity: bool = False
+    network_editor: List[EditRule] = field(default_factory=list)
+    simulate_scheduler: bool = False
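Worth noting: the old DebugConfig signature used a mutable default argument (network_editor: List[EditRule] = []), the classic shared-list pitfall, which the dataclass sidesteps by requiring default_factory for mutable defaults. A short illustration with a hypothetical stand-in:

from dataclasses import dataclass, field
from typing import List

def old_style(rules=[]):          # default list is created once, at def time
    rules.append("rule")
    return rules

print(old_style(), old_style())   # ['rule', 'rule'] ['rule', 'rule'] - same shared list

@dataclass
class DebugConfigSketch:
    network_editor: List[str] = field(default_factory=list)  # fresh list per instance

assert DebugConfigSketch().network_editor is not DebugConfigSketch().network_editor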
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-
+from dataclasses import dataclass, field
 import math
 from enum import Enum
 
@@ -46,86 +46,44 @@ class QuantizationErrorMethod(Enum):
     HMSE = 6
 
 
+@dataclass
 class QuantizationConfig:
+    """
+    A class that encapsulates all the different parameters used by the library to quantize a model.
+
+    Examples:
+        You can create a quantization configuration to apply to a model. For example, to quantize a model's weights and
+        activations using thresholds, with weight threshold selection based on MSE and activation threshold selection
+        using NOCLIPPING (min/max), while enabling relu_bound_to_power_of_2 and weights_bias_correction,
+        you can instantiate a quantization configuration like this:
+
+        >>> import model_compression_toolkit as mct
+        >>> qc = mct.core.QuantizationConfig(activation_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING, weights_error_method=mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=True, weights_bias_correction=True)
+
+
+        The QuantizationConfig instance can then be used in the quantization workflow,
+        such as with Keras in the function: :func:~model_compression_toolkit.ptq.keras_post_training_quantization`.
+
+    """
 
-    def __init__(self,
-                 activation_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE,
-                 weights_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE,
-                 relu_bound_to_power_of_2: bool = False,
-                 weights_bias_correction: bool = True,
-                 weights_second_moment_correction: bool = False,
-                 input_scaling: bool = False,
-                 softmax_shift: bool = False,
-                 shift_negative_activation_correction: bool = True,
-                 activation_channel_equalization: bool = False,
-                 z_threshold: float = math.inf,
-                 min_threshold: float = MIN_THRESHOLD,
-                 l_p_value: int = 2,
-                 linear_collapsing: bool = True,
-                 residual_collapsing: bool = True,
-                 shift_negative_ratio: float = 0.05,
-                 shift_negative_threshold_recalculation: bool = False,
-                 shift_negative_params_search: bool = False,
-                 concat_threshold_update: bool = False):
-        """
-        Class to wrap all different parameters the library quantize the input model according to.
-
-        Args:
-            activation_error_method (QuantizationErrorMethod): Which method to use from QuantizationErrorMethod for activation quantization threshold selection.
-            weights_error_method (QuantizationErrorMethod): Which method to use from QuantizationErrorMethod for activation quantization threshold selection.
-            relu_bound_to_power_of_2 (bool): Whether to use relu to power of 2 scaling correction or not.
-            weights_bias_correction (bool): Whether to use weights bias correction or not.
-            weights_second_moment_correction (bool): Whether to use weights second_moment correction or not.
-            input_scaling (bool): Whether to use input scaling or not.
-            softmax_shift (bool): Whether to use softmax shift or not.
-            shift_negative_activation_correction (bool): Whether to use shifting negative activation correction or not.
-            activation_channel_equalization (bool): Whether to use activation channel equalization correction or not.
-            z_threshold (float): Value of z score for outliers removal.
-            min_threshold (float): Minimum threshold to use during thresholds selection.
-            l_p_value (int): The p value of L_p norm threshold selection.
-            block_collapsing (bool): Whether to collapse block one to another in the input network
-            shift_negative_ratio (float): Value for the ratio between the minimal negative value of a non-linearity output to its activation threshold, which above it - shifting negative activation should occur if enabled.
-            shift_negative_threshold_recalculation (bool): Whether or not to recompute the threshold after shifting negative activation.
-            shift_negative_params_search (bool): Whether to search for optimal shift and threshold in shift negative activation.
-
-        Examples:
-            One may create a quantization configuration to quantize a model according to.
-            For example, to quantize a model's weights and activation using thresholds, such that
-            weights threshold selection is done using MSE, activation threshold selection is done using NOCLIPPING (min/max),
-            enabling relu_bound_to_power_of_2, weights_bias_correction,
-            one can instantiate a quantization configuration:
-
-            >>> import model_compression_toolkit as mct
-            >>> qc = mct.core.QuantizationConfig(activation_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING, weights_error_method=mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=True, weights_bias_correction=True)
-
-
-            The QuantizationConfig instanse can then be passed to
-            :func:`~model_compression_toolkit.ptq.keras_post_training_quantization`
-
-        """
-
-        self.activation_error_method = activation_error_method
-        self.weights_error_method = weights_error_method
-        self.relu_bound_to_power_of_2 = relu_bound_to_power_of_2
-        self.weights_bias_correction = weights_bias_correction
-        self.weights_second_moment_correction = weights_second_moment_correction
-        self.activation_channel_equalization = activation_channel_equalization
-        self.input_scaling = input_scaling
-        self.softmax_shift = softmax_shift
-        self.min_threshold = min_threshold
-        self.shift_negative_activation_correction = shift_negative_activation_correction
-        self.z_threshold = z_threshold
-        self.l_p_value = l_p_value
-        self.linear_collapsing = linear_collapsing
-        self.residual_collapsing = residual_collapsing
-        self.shift_negative_ratio = shift_negative_ratio
-        self.shift_negative_threshold_recalculation = shift_negative_threshold_recalculation
-        self.shift_negative_params_search = shift_negative_params_search
-        self.concat_threshold_update = concat_threshold_update
-
-    def __repr__(self):
-        # Used for debugging, thus no cover.
-        return str(self.__dict__)  # pragma: no cover
+    activation_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE
+    weights_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE
+    relu_bound_to_power_of_2: bool = False
+    weights_bias_correction: bool = True
+    weights_second_moment_correction: bool = False
+    input_scaling: bool = False
+    softmax_shift: bool = False
+    shift_negative_activation_correction: bool = True
+    activation_channel_equalization: bool = False
+    z_threshold: float = math.inf
+    min_threshold: float = MIN_THRESHOLD
+    l_p_value: int = 2
+    linear_collapsing: bool = True
+    residual_collapsing: bool = True
+    shift_negative_ratio: float = 0.05
+    shift_negative_threshold_recalculation: bool = False
+    shift_negative_params_search: bool = False
+    concat_threshold_update: bool = False
 
 
 # Default quantization configuration the library use.
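Dropping the hand-written __repr__ here loses nothing: @dataclass generates __repr__ (and __eq__) from the declared fields, which is presumably why the version returning str(self.__dict__) was removed. Illustrated with a hypothetical stand-in:

import math
from dataclasses import dataclass

@dataclass
class QcSketch:
    weights_bias_correction: bool = True
    z_threshold: float = math.inf

# The generated __repr__ names each field; __eq__ compares field-wise.
print(QcSketch())                 # QcSketch(weights_bias_correction=True, z_threshold=inf)
assert QcSketch() == QcSketch()   # value equality for free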
@@ -360,7 +360,7 @@ def shift_negative_function(graph: Graph,
                                                   graph=graph,
                                                   quant_config=core_config.quantization_config,
                                                   tpc=graph.tpc,
-                                                  mixed_precision_enable=core_config.mixed_precision_enable)
+                                                  mixed_precision_enable=core_config.is_mixed_precision_enabled)
 
     for candidate_qc in pad_node.candidates_quantization_cfg:
         candidate_qc.activation_quantization_cfg.enable_activation_quantization = False
@@ -377,7 +377,7 @@ def shift_negative_function(graph: Graph,
                                                   graph=graph,
                                                   quant_config=core_config.quantization_config,
                                                   tpc=graph.tpc,
-                                                  mixed_precision_enable=core_config.mixed_precision_enable)
+                                                  mixed_precision_enable=core_config.is_mixed_precision_enabled)
 
     original_non_linear_activation_nbits = non_linear_node_cfg_candidate.activation_n_bits
     # The non-linear node's output should be float, so we approximate it by using 16bits quantization.
@@ -119,7 +119,7 @@ def core_runner(in_model: Any,
                                     tpc,
                                     core_config.bit_width_config,
                                     tb_w,
-                                    mixed_precision_enable=core_config.mixed_precision_enable,
+                                    mixed_precision_enable=core_config.is_mixed_precision_enabled,
                                     running_gptq=running_gptq)
 
     hessian_info_service = HessianInfoService(graph=graph, representative_dataset_gen=representative_data_gen,
@@ -136,7 +136,7 @@ def core_runner(in_model: Any,
     ######################################
     # Finalize bit widths
     ######################################
-    if core_config.mixed_precision_enable:
+    if core_config.is_mixed_precision_enabled:
         if core_config.mixed_precision_config.configuration_overwrite is None:
 
             filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, tpc)
@@ -161,7 +161,7 @@ def core_runner(in_model: Any,
     else:
         bit_widths_config = []
 
-    tg = set_bit_widths(core_config.mixed_precision_enable,
+    tg = set_bit_widths(core_config.is_mixed_precision_enabled,
                         tg,
                         bit_widths_config)
 
@@ -175,7 +175,7 @@ def core_runner(in_model: Any,
                                   fw_info=fw_info,
                                   fw_impl=fw_impl)
 
-    if core_config.mixed_precision_enable:
+    if core_config.is_mixed_precision_enabled:
         # Retrieve lists of tuples (node, node's final weights/activation bitwidth)
         weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
         activation_conf_nodes_bitwidth = tg.get_final_activation_config()
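The runner hunks above and the facade hunks below are the mechanical fallout of the property rename: every call site that read core_config.mixed_precision_enable now reads core_config.is_mixed_precision_enabled. No backward-compatibility alias appears in this diff, so downstream code that touches the property directly would need a one-for-one update, roughly:

import model_compression_toolkit as mct

core_config = mct.core.CoreConfig()
core_config.mixed_precision_config.set_mixed_precision_enable()

# Before this nightly: core_config.mixed_precision_enable
# From 2.2.0.20240912.453 on:
assert core_config.is_mixed_precision_enabled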
@@ -199,7 +199,7 @@ if FOUND_TF:
         KerasModelValidation(model=in_model,
                              fw_info=DEFAULT_KERAS_INFO).validate()
 
-        if core_config.mixed_precision_enable:
+        if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                                 "Ensure usage of the correct API for keras_post_training_quantization "
@@ -165,7 +165,7 @@ if FOUND_TORCH:
 
         """
 
-        if core_config.mixed_precision_enable:
+        if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                                 "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
@@ -124,7 +124,7 @@ if FOUND_TF:
         KerasModelValidation(model=in_model,
                              fw_info=fw_info).validate()
 
-        if core_config.mixed_precision_enable:
+        if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
                                 "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
@@ -96,7 +96,7 @@ if FOUND_TORCH:
 
         fw_info = DEFAULT_PYTORCH_INFO
 
-        if core_config.mixed_precision_enable:
+        if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
                                 "MixedPrecisionQuantizationConfig. Please use "
@@ -176,7 +176,7 @@ if FOUND_TF:
         KerasModelValidation(model=in_model,
                              fw_info=DEFAULT_KERAS_INFO).validate()
 
-        if core_config.mixed_precision_enable:
+        if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
                                 "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
@@ -145,7 +145,7 @@ if FOUND_TORCH:
                              f"If you encounter an issue, please open an issue in our GitHub "
                              f"project https://github.com/sony/model_optimization")
 
-        if core_config.mixed_precision_enable:
+        if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
                                 "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"