mct-nightly 1.11.0.20240303.post423__py3-none-any.whl → 1.11.0.20240305.post352__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/METADATA +5 -5
  2. {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/RECORD +32 -30
  3. model_compression_toolkit/core/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +2 -2
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  7. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
  9. model_compression_toolkit/core/common/quantization/core_config.py +3 -3
  10. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
  11. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
  13. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  14. model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
  15. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
  16. model_compression_toolkit/core/pytorch/constants.py +3 -0
  17. model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
  18. model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
  19. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
  21. model_compression_toolkit/gptq/keras/quantization_facade.py +4 -4
  22. model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -3
  23. model_compression_toolkit/pruning/__init__.py +1 -0
  24. model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
  25. model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
  26. model_compression_toolkit/ptq/keras/quantization_facade.py +4 -7
  27. model_compression_toolkit/ptq/pytorch/quantization_facade.py +3 -6
  28. model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
  29. model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
  31. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
  32. {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/LICENSE.md +0 -0
  33. {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/WHEEL +0 -0
  34. {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 1.11.0.20240303.post423
3
+ Version: 1.11.0.20240305.post352
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -173,11 +173,11 @@ This pruning technique is designed to compress models for specific hardware arch
173
173
  taking into account the target platform's Single Instruction, Multiple Data (SIMD) capabilities.
174
174
  By pruning groups of channels (SIMD groups), our approach not only reduces model size
175
175
  and complexity, but also ensures better utilization of channels in line with the SIMD architecture
176
- for a target KPI of weights memory footprint.
176
+ for a target KPI of weights memory footprint.
177
+ [Keras API](https://sony.github.io/model_optimization/docs/api/experimental_api_docs/methods/keras_pruning_experimental.html)
178
+ [Pytorch API](https://sony.github.io/model_optimization/docs/api/experimental_api_docs/methods/pytorch_pruning_experimental.html)
177
179
 
178
180
 
179
- <u>_Note: Currently, only Keras models pruning is supported._</u>
180
-
181
181
  #### Results
182
182
 
183
183
  Results for applying pruning to reduce the parameters of the following models by 50%:
@@ -185,7 +185,7 @@ Results for applying pruning to reduce the parameters of the following models by
185
185
  | Model | Dense Model Accuracy | Pruned Model Accuracy |
186
186
  |-----------------|----------------------|-----------------------|
187
187
  | ResNet50 [2] | 75.1 | 72.4 |
188
- | DenseNet121 [2] | 75.0 | 71.15 |
188
+ | DenseNet121 [3] | 74.44 | 71.71 |
189
189
 
190
190
 
191
191
 
@@ -2,7 +2,7 @@ model_compression_toolkit/__init__.py,sha256=7NFeY28X4NSHeLTq1JcKkAnp031EMjorU-I
2
2
  model_compression_toolkit/constants.py,sha256=_OW_bUeQmf08Bb4oVZ0KfUt-rcCeNOmdBv3aP7NF5fM,3631
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=b9DVktZ-LymFcRxv2aL_sdiE6S2sSrFGWltx6dgEuUY,4863
5
- model_compression_toolkit/core/__init__.py,sha256=AXC_DLak39yxBGURNPm6LjtxB1P7iD1O3Iq-q6AA5xM,1874
5
+ model_compression_toolkit/core/__init__.py,sha256=P-7OYR4TFYxVV_ZpIJBogkX8bGvXcijlF65Ez3ivjhc,1838
6
6
  model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
7
7
  model_compression_toolkit/core/exporter.py,sha256=U_-ea-zYHsnIt2ydameMLZ_gzDaCMI1dRa5IjA8RUuc,4233
8
8
  model_compression_toolkit/core/graph_prep_runner.py,sha256=3xp0WYqyeRdlBkf5R6uD2zWubg_JPttOwS7JRhKykBY,10043
@@ -11,7 +11,7 @@ model_compression_toolkit/core/runner.py,sha256=RgN9l0v7aFYu6MTuIZGAB2syr6NBqG_v
11
11
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
12
12
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
13
13
  model_compression_toolkit/core/common/data_loader.py,sha256=jCoVIb4yeOWyCrCNRB1W-mgLSyqNVGEepFXrIqufVc4,4119
14
- model_compression_toolkit/core/common/framework_implementation.py,sha256=pWbFiioYqL69-_bGNdxerLzGgP-HSiegRzKyGkAo27w,20533
14
+ model_compression_toolkit/core/common/framework_implementation.py,sha256=n3T0uOfeni_P-5ut4ModjPdtRUWCWzfAIOW0LdOlVb4,20529
15
15
  model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
16
16
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
17
17
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
@@ -62,10 +62,10 @@ model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256
62
62
  model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,sha256=LLDguK7afsbN742ucLpmJr5TUfTyFpK1vbf2bpVr1v0,882
63
63
  model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=kmyBcqGh3qYqo42gIZzouQEljTNpF9apQt6cXEVkTQ0,3871
64
64
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=x0cweemRG3_7FlvAbxFK5Zi77qpoKAGqtGndY8MtgwM,2222
65
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=n7ZwJSd-6OT59Zg8eJW9PW_NjHy-2q2F4IKyV-_p9TM,8032
66
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=eRohepDyOr_cXVjgp2EnntPSIbWDADH0DIjP2Wc15oQ,7107
65
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=wKiG-d_qg5dgPpZRK6dvP0MCXUKIfDNpeCHQrI3JxqY,4245
66
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=mN-QeabIu_Mz1IzPeQjqgqprCTdwGm4ThYX0gZAek-E,7103
67
67
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=XrHMQAzIUwyHqrY30cwK1lP5jKI0toeWFEG4e3qKzlY,35988
68
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=xo_RyCUf2I-nvjof6p_40QMgiUvNpr77iyTZhH9u8ug,28227
68
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=z_0hz6mX27Q-Xzg-ftvoyPrw_i9faBGslkP0u0WGxTg,28223
69
69
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
70
70
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=A0-XsePLlqeYDgAFrDLJjYa4mY3VFKRWYQ3vHpPE-c4,6326
71
71
  model_compression_toolkit/core/common/mixed_precision/kpi_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -83,7 +83,7 @@ model_compression_toolkit/core/common/network_editors/node_filters.py,sha256=uML
83
83
  model_compression_toolkit/core/common/pruning/__init__.py,sha256=DGJybkDQtKMSMFoZ-nZ3ZifA8uJ6G_D20wHhKHNlmU0,699
84
84
  model_compression_toolkit/core/common/pruning/channels_grouping.py,sha256=4jsr1xEBNpV2c4ipi366IfHoHCJVqoRUTTOJdlRomvc,3892
85
85
  model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py,sha256=Go3FPHEyvm_Ra0c3XmdUVEFLCGWFoddMuYXd3Trdxkw,7594
86
- model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=3VcvCfJGsYSS0VVdP5uaWoI2cncBYtSEU-PwF-OG3E0,18320
86
+ model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=RnMmgNDHekKFOj-b-ad5rhjuKUvbVawy1A31nxuCRTg,19217
87
87
  model_compression_toolkit/core/common/pruning/prune_graph.py,sha256=ddbZLuWvlNoj5so_5NRbIuG5qDFxD9ApG2gPirbov8o,3317
88
88
  model_compression_toolkit/core/common/pruning/pruner.py,sha256=vXxzBXQ-oAEnw6PAD1SUiNXX7Xix4JJ0LAmV04sjFz0,7313
89
89
  model_compression_toolkit/core/common/pruning/pruning_config.py,sha256=IfF824hNttyw2i4Tuf3g8CUfelJR3eZuOLzf2aEZNAM,3442
@@ -99,19 +99,18 @@ model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=hk
99
99
  model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=gmzD32xsfJH8vkkqaspS7vYa6VWayk1GJe-NfoAEugQ,5901
100
100
  model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
101
101
  model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=VW0vGCPKMxbhl2cB_zHx3g9c7cNS4ctVEAvnaNq17jw,5153
102
- model_compression_toolkit/core/common/quantization/core_config.py,sha256=8DRM4Ar4Er-bllo56LG-Lcx9U2Ebd3jJctf4t2hOcXc,2021
102
+ model_compression_toolkit/core/common/quantization/core_config.py,sha256=IkD4Jl9PWdPucfUMq0TtyUl5DBJvha7Dd2xSW7_7dz8,2015
103
103
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
104
104
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=4GCr4Z6pRMbxIAnq4s7YtdMSqwbRwUzTzCFfs2ahVfk,6137
105
105
  model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=Q634XzMtjqReiLni8974y13apzbZ9nref-XBGjH17-0,16761
106
106
  model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=0547euhiaEX4vgpvIwHd7-pZ3iI7QmIc6Y_qHV4Y5sY,6713
107
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=9GWuVMW9ifWnflUjLZoHZtcd8NpHyHTUzCdJFKIaGlo,2352
108
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=sEPDeClFxh0uHEGznX7E3bSOJ_t0kUvyWcdxcyZJdwA,4090
107
+ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=f2Qa2majjO-gIN3lxqsA8icKJ9FMP-sKbw3lI6XNgBg,2137
108
+ model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=mrgVzZszWjxnjT8zm77UVLWKTOwd2thGBo6WNqAS4X8,3867
109
109
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=_OQEFAdYDTHu2Qp-qs02Z1CDxugUKG6k5eCePS1WpXY,2939
110
110
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=UK_YshvZI0-LrKeT9gFGYcMA7pma1kaR5JAfzJH3HNw,3614
111
111
  model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=qkrWJXLyDSIJhvT8tO9Nh51f4abyVR8zMFuaaMRRrRw,12304
112
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=_U4IFPuzGyyAymjDjsPl2NF6UbFggqBaiA1Td3sug3I,1608
112
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
113
113
  model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=rwCedE0zggamSBY50rqh-xqZpIMrn8o96YH_jMCuPrk,16505
114
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py,sha256=qDfJbvY64KLOG6n18ddEPTFGrKHlaXzZ136TrVpgH9s,2917
115
114
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=h8Zmpq3KdcsdUUy7K1fvWOVSki0mxT8wtKZXGmgFl74,7405
116
115
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
117
116
  model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=W4j9IB1Grj_Ku1pLjPxb-HLcYU9LTDuf9_0JilbqU2w,8484
@@ -122,7 +121,6 @@ model_compression_toolkit/core/common/quantization/quantization_params_generatio
122
121
  model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=53OFL3IZxtH-UPakf3h_LZkaZAa3cgc2oqgMUe3Sg8o,9689
123
122
  model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=oiJn1twYpTaq_z5qX4d8_nnk_jouYWHq8z0WAknl5oE,7879
124
123
  model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
125
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py,sha256=ZS3IXGbUTW580vwVF5jgxfPVYVL3tQrpvoSqDxVu7zQ,2325
126
124
  model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py,sha256=P0x_y18LypBxP2tV9OWizheYfILqvaMC8RwHo04sUpQ,2761
127
125
  model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py,sha256=5JuPwb9HDHaYQj1YyNWGY7GdjJ105Yr8iEEZhzfuRW4,14190
128
126
  model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py,sha256=FVeuK-LeuAsRFcqo5uaNHmb6oTOFs21ltghtqswl6KM,5486
@@ -152,10 +150,10 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
152
150
  model_compression_toolkit/core/keras/constants.py,sha256=YhuzRqXAdkRFzLT5lRD_jtLVYcUb-d4fUm-D49z5XOg,3158
153
151
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
154
152
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_KTS5Po1Xnv6GyK9eprpDhYWf-eooS62Ys,4961
155
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=j6sJnH5UUpnyXAji01XVl8UiuHQ7JFCUpm7pRLXwYas,28148
153
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=kvrZeq3gmHeJKZXPyz55NGwO3EseTgo6022ZCWaUPW8,28144
156
154
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
157
155
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
158
- model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=0OkwSeuDMUQRelgvp4mz8enLsaSamsp5ePGkRKrZQOU,4826
156
+ model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=nZdRhQuIXjoL3sq2HjffKHWfnafVM-j_oqk0Cc5Op3I,4889
159
157
  model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=1kBs9URqZTfmRXAsCqvnekV5bKUL3MyqGbORewLIwu8,2457
160
158
  model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nBgUZ-baE3M6SzCuQbcnq4iebY1jtJBvKHOM,808
161
159
  model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=GSh1Piz5qpA7IlvHTMqUvPn7WBDa0IHEDZdd_TzY9XA,2226
@@ -192,7 +190,7 @@ model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1
192
190
  model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=s4m5jR2OsBCgDY0W0rB2T3YQwZPQvhZvNfSAPc9Vggo,4922
193
191
  model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=X_1qI-1iD6HxG6ERe48OlCelw-nn3k4NZpeV1vtfNOs,6363
194
192
  model_compression_toolkit/core/keras/pruning/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
195
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=olRbKncMy159Mp2C9UC8VomQ55rDXYVx2hCJ-atQsic,12455
193
+ model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=KUnSnpx8ty8NmNc6HA-HiWg-FBg7rBwkbLgOM_5ZfJA,12676
196
194
  model_compression_toolkit/core/keras/quantizer/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
197
195
  model_compression_toolkit/core/keras/quantizer/base_quantizer.py,sha256=eMRjAUU189-AVwNGMlV0M-ZlL48ZYmILzutheUT00xU,1628
198
196
  model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py,sha256=Oi64CD83OopPoQNAarl2MJRbCKujU2W8Wdrs9KOPNWk,6151
@@ -211,11 +209,11 @@ model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9H
211
209
  model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
212
210
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
213
211
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
214
- model_compression_toolkit/core/pytorch/constants.py,sha256=58SdcH1N3mfU4sAx9E54EdizY04bY4gKSf10dCGbW0U,2534
212
+ model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
215
213
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
216
- model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=etYbVo5dwkBwjt8dqlrpGFijvd9y8oP_KX4JsdubnQM,4978
214
+ model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=xQdg8vtkwhx1uzElrh0KtwKdWFr6b2Guzv913iA_VoI,4978
217
215
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
218
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=KGwEgOb4jwSWGi9mGd6DDBLcWOcveq41SMjx_0z-wlE,26078
216
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=EEHOnsXsvWlNm4x-4ADuFHsIc8vLrZP9pMh3raegaNg,26074
219
217
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
220
218
  model_compression_toolkit/core/pytorch/utils.py,sha256=Bxep5o-Zw3O_iLdlzidiU186r0O6MXXtkJIZm-L1PiA,2847
221
219
  model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
@@ -254,6 +252,8 @@ model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_
254
252
  model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
255
253
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=iEtLjZw4l8Oyj81QnxbQFRKVuTbQEI98P5LuU7SNGq8,4656
256
254
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=VUN9vvWQWAh281C0xgV3w4T2DkSaxFZ-xmBgF50vGdo,5961
255
+ model_compression_toolkit/core/pytorch/pruning/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
256
+ model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=a1-P0ZI8jO1qlYNWqas9AvRb3bwcrNluul6qPc140R8,14601
257
257
  model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
258
258
  model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=rox-f5wbRyxU1UHeHyaoIDXB9r9fCXm1dPN4FVwHqTc,6464
259
259
  model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
@@ -341,7 +341,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
341
341
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
342
342
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=75j56X2AcNI_5hInsLvXnWZOGMZIKQY2hStVKBaA_Bc,17705
343
343
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=FIGqzJbG6GkdHenvdMu-tGTjp4j9BewdF_spmWCb4Mo,4627
344
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=UryGNh9JudVsPeUYgXRsC3zqPuWjlQlPxpt3v9YW00w,15170
344
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=STtrf7bzqVr1P1Pah3cftGRcYSF7i-tSkcjrcuMdpTw,15162
345
345
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
346
346
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=8NrJBftKFbMAF_jYaAbLP6GBwpCv3Ln1NKURaV75zko,4770
347
347
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -358,7 +358,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
358
358
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
359
359
  model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=RfkJNQbeqGZQvlmV0dZO7YJ894Gx2asLnnIHFdWNEZ0,15078
360
360
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=-0GDC2cr-XXS7cTFTnDflJivGN7VaPnzVPsxCE-vZNU,3955
361
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=ilP9B4rkSWOB94HF5fR8NKl3mZdYrdDARSxjKW92JvQ,13324
361
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=Mr0AWmqaxdr9O4neBej4hB6BWxwP_okHjQ2SwUoVvrM,13318
362
362
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
363
363
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=Zb-P0yRyZHHBlDvUBdRwxDpdduEJyJp6OT9pfKFF5ks,4171
364
364
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -370,20 +370,22 @@ model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_qu
370
370
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
371
371
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
372
372
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=6uxq_w62jn8DDOt9T7VtA6jZ8jTAPcbTufKFOYpVUm4,8768
373
- model_compression_toolkit/pruning/__init__.py,sha256=inDsxrBI8aXkYI9ZeIeNVTwa5D4cv2ZQ2q9v7hb0_Og,1008
373
+ model_compression_toolkit/pruning/__init__.py,sha256=lQMZS8G0pvR1LVi53nnJHNXgLNTan_MWMdwsVxhjrow,1106
374
374
  model_compression_toolkit/pruning/keras/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
375
375
  model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=PHKZYBHrVyR348-a6gw44NrV8Ra9iaeFJ0WbWYpzX8k,8020
376
+ model_compression_toolkit/pruning/pytorch/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
377
+ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=hO13kbCIyoAgVW6hRoJCUpHr6ThOuqz1Vkg-cftJY5k,8906
376
378
  model_compression_toolkit/ptq/__init__.py,sha256=50QBTXOKdj9XLjXtrvf0mhC9FlW6TOi9-pjl96RLR14,930
377
379
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
378
380
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
379
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=39-bHTOZ7OENxvbC_HKelwjSO3e8BccGexgwWPClIDk,9969
381
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=J6luWnjkyk8KvXwZnVb5sEG9fynhXIOBMbIRPXDs2g8,9805
380
382
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
381
- model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=2r_2BRAocLiMciGBa59wbpDt6U4RWMLsJW0hIwcIbLQ,8722
383
+ model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=gsr_tPfY9HbAxKWaHgiq3J_osXUkOoazdLuvWdyyg3A,8567
382
384
  model_compression_toolkit/qat/__init__.py,sha256=BYKgH1NwB9fqF1TszULQ5tDfLI-GqgZV5sao-lDN9EM,1091
383
385
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
384
386
  model_compression_toolkit/qat/common/qat_config.py,sha256=kbSxFL6_u28furq5mW_75STWDmyX4clPt-seJAnX3IQ,3445
385
387
  model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
386
- model_compression_toolkit/qat/keras/quantization_facade.py,sha256=hB_VrhSqwOjGnOT8BYpXkh52EMzo7I-62-IEGohg_74,16253
388
+ model_compression_toolkit/qat/keras/quantization_facade.py,sha256=3hdUBm-Q8mRe_sBrpHKFguVaUQPEH95ZeiyYN_EclwI,16073
387
389
  model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
388
390
  model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=gPuIgQb8OafvC3SuA8jNsGoy8S8eTsDCEKuh36WDNss,2104
389
391
  model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -395,7 +397,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
395
397
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
396
398
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
397
399
  model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
398
- model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=Kv2vUKY2cSt6kYB7J2B1SC_PGzSFqecpxDcKWGeRzuQ,12629
400
+ model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=pNBkkT9Ey1CmrQUqg-NHdf4tFVjgoVldX8bqXJp9lik,12473
399
401
  model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
400
402
  model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=FnhuFCuQoSf78FM1z1UZgXXd3k-mKSM7i9dYOuJUmeA,2213
401
403
  model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=GOYRDXvQSGe_iUFVmvDy5BqC952hu_-rQO06n8QCyw0,5491
@@ -470,8 +472,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
470
472
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
471
473
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
472
474
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
473
- mct_nightly-1.11.0.20240303.post423.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
474
- mct_nightly-1.11.0.20240303.post423.dist-info/METADATA,sha256=7TlXG0QPNK4ZWwDIaw_T_tOAkcIM7jwcxjA1v6RVudQ,17187
475
- mct_nightly-1.11.0.20240303.post423.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
476
- mct_nightly-1.11.0.20240303.post423.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
477
- mct_nightly-1.11.0.20240303.post423.dist-info/RECORD,,
475
+ mct_nightly-1.11.0.20240305.post352.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
476
+ mct_nightly-1.11.0.20240305.post352.dist-info/METADATA,sha256=_XbUd27NO_Oa_pHVS9Jw32qfVPNAtdEeCOqXXLAoPhA,17377
477
+ mct_nightly-1.11.0.20240305.post352.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
478
+ mct_nightly-1.11.0.20240305.post352.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
479
+ mct_nightly-1.11.0.20240305.post352.dist-info/RECORD,,
@@ -22,6 +22,6 @@ from model_compression_toolkit.core.common.mixed_precision import mixed_precisio
22
22
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG
23
23
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
24
24
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
25
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2
25
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
26
26
  from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data
27
27
  from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data
@@ -18,7 +18,7 @@ from typing import Callable, Any, List, Tuple, Dict
18
18
  import numpy as np
19
19
 
20
20
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
21
- from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2
21
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
22
22
  from model_compression_toolkit.core import common
23
23
  from model_compression_toolkit.core.common import BaseNode
24
24
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -304,7 +304,7 @@ class FrameworkImplementation(ABC):
304
304
  @abstractmethod
305
305
  def get_sensitivity_evaluator(self,
306
306
  graph: Graph,
307
- quant_config: MixedPrecisionQuantizationConfigV2,
307
+ quant_config: MixedPrecisionQuantizationConfig,
308
308
  representative_data_gen: Callable,
309
309
  fw_info: FrameworkInfo,
310
310
  hessian_info_service: HessianInfoService = None,
@@ -13,16 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from enum import Enum
17
- from typing import List, Callable, Tuple
16
+ from typing import List, Callable
18
17
 
19
- from model_compression_toolkit.logger import Logger
20
18
  from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
21
- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, DEFAULTCONFIG
22
19
  from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
23
20
 
24
21
 
25
- class MixedPrecisionQuantizationConfigV2:
22
+ class MixedPrecisionQuantizationConfig:
26
23
 
27
24
  def __init__(self,
28
25
  compute_distance_fn: Callable = None,
@@ -36,8 +33,6 @@ class MixedPrecisionQuantizationConfigV2:
36
33
  metric_normalization_threshold: float = 1e10):
37
34
  """
38
35
  Class with mixed precision parameters to quantize the input model.
39
- Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
40
- support mixed-precision model quantization.
41
36
 
42
37
  Args:
43
38
  compute_distance_fn (Callable): Function to compute a distance between two tensors.
@@ -70,67 +65,6 @@ class MixedPrecisionQuantizationConfigV2:
70
65
  self.metric_normalization_threshold = metric_normalization_threshold
71
66
 
72
67
 
73
- class MixedPrecisionQuantizationConfig(QuantizationConfig):
74
-
75
- def __init__(self,
76
- qc: QuantizationConfig = DEFAULTCONFIG,
77
- compute_distance_fn: Callable = compute_mse,
78
- distance_weighting_method: Callable = get_average_weights,
79
- num_of_images: int = 32,
80
- configuration_overwrite: List[int] = None,
81
- num_interest_points_factor: float = 1.0):
82
- """
83
- Class to wrap all different parameters the library quantize the input model according to.
84
- Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
85
- support mixed-precision model quantization.
86
-
87
- Args:
88
- qc (QuantizationConfig): QuantizationConfig object containing parameters of how the model should be quantized.
89
- compute_distance_fn (Callable): Function to compute a distance between two tensors.
90
- distance_weighting_method (Callable): Function to use when weighting the distances among different layers when computing the sensitivity metric.
91
- num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
92
- configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
93
- num_interest_points_factor: A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
94
-
95
- """
96
-
97
- super().__init__(**qc.__dict__)
98
- self.compute_distance_fn = compute_distance_fn
99
- self.distance_weighting_method = distance_weighting_method
100
- self.num_of_images = num_of_images
101
- self.configuration_overwrite = configuration_overwrite
102
-
103
- assert 0.0 < num_interest_points_factor <= 1.0, "num_interest_points_factor should represent a percentage of " \
104
- "the base set of interest points that are required to be " \
105
- "used for mixed-precision metric evaluation, " \
106
- "thus, it should be between 0 to 1"
107
- self.num_interest_points_factor = num_interest_points_factor
108
-
109
- def separate_configs(self) -> Tuple[QuantizationConfig, MixedPrecisionQuantizationConfigV2]:
110
- """
111
- A function to separate the old MixedPrecisionQuantizationConfig into QuantizationConfig
112
- and MixedPrecisionQuantizationConfigV2
113
-
114
- Returns: QuantizationConfig, MixedPrecisionQuantizationConfigV2
115
-
116
- """
117
- _dummy_quant_config = QuantizationConfig()
118
- _dummy_mp_config_experimental = MixedPrecisionQuantizationConfigV2()
119
- qc_dict = {}
120
- mp_dict = {}
121
- for k, v in self.__dict__.items():
122
- if hasattr(_dummy_quant_config, k):
123
- qc_dict.update({k: v})
124
- elif hasattr(_dummy_mp_config_experimental, k):
125
- mp_dict.update({k: v})
126
- else:
127
- Logger.error(f'Attribute "{k}" mismatch: exists in MixedPrecisionQuantizationConfig but not in '
128
- f'MixedPrecisionQuantizationConfigV2') # pragma: no cover
129
-
130
- return QuantizationConfig(**qc_dict), MixedPrecisionQuantizationConfigV2(**mp_dict)
131
-
132
-
133
68
  # Default quantization configuration the library use.
134
- DEFAULT_MIXEDPRECISION_CONFIG = MixedPrecisionQuantizationConfig(DEFAULTCONFIG,
135
- compute_mse,
136
- get_average_weights)
69
+ DEFAULT_MIXEDPRECISION_CONFIG = MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse,
70
+ distance_weighting_method=get_average_weights)
@@ -18,7 +18,7 @@ from enum import Enum
18
18
  import numpy as np
19
19
  from typing import List, Callable, Dict
20
20
 
21
- from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2
21
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
22
22
  from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.hessian import HessianInfoService
24
24
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
@@ -48,7 +48,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
48
48
  fw_info: FrameworkInfo,
49
49
  fw_impl: FrameworkImplementation,
50
50
  target_kpi: KPI,
51
- mp_config: MixedPrecisionQuantizationConfigV2,
51
+ mp_config: MixedPrecisionQuantizationConfig,
52
52
  representative_data_gen: Callable,
53
53
  search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING,
54
54
  hessian_info_service: HessianInfoService=None) -> List[int]:
@@ -18,7 +18,7 @@ import numpy as np
18
18
  from typing import Callable, Any, List, Tuple
19
19
 
20
20
  from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA
21
- from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfigV2
21
+ from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
22
22
  from model_compression_toolkit.core.common import Graph, BaseNode
23
23
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
24
24
 
@@ -37,7 +37,7 @@ class SensitivityEvaluation:
37
37
 
38
38
  def __init__(self,
39
39
  graph: Graph,
40
- quant_config: MixedPrecisionQuantizationConfigV2,
40
+ quant_config: MixedPrecisionQuantizationConfig,
41
41
  representative_data_gen: Callable,
42
42
  fw_info: FrameworkInfo,
43
43
  fw_impl: Any,
@@ -307,7 +307,25 @@ class MemoryCalculator:
307
307
  total_params += w.size
308
308
 
309
309
  # Adjust the total parameter count if padded channels are to be included.
310
- num_oc = np.sum(output_mask) if output_mask is not None else node.output_shape[-1]
310
+ if output_mask is not None:
311
+ num_oc = np.sum(output_mask)
312
+ else:
313
+ # Get the node channel axis from framework info
314
+ channel_axis = self.fw_info.out_channel_axis_mapping.get(node.type)
315
+ if channel_axis is None:
316
+ Logger.error("Channel axis is not defined")
317
+
318
+ # Check if node.output_shape is a list of lists.
319
+ # In this case make sure all the out channels are the same value
320
+ if all(isinstance(sublist, list) for sublist in node.output_shape):
321
+ compare_value = node.output_shape[0][channel_axis]
322
+ if all(len(sublist) > channel_axis and sublist[channel_axis] == compare_value for sublist in node.output_shape):
323
+ num_oc = compare_value
324
+ else:
325
+ Logger.error("Number of out channels are not the same for all outputs of the node")
326
+ else:
327
+ num_oc = node.output_shape[channel_axis]
328
+
311
329
  if include_padded_channels:
312
330
  total_params = self.get_node_nparams_with_padded_channels(node, total_params, num_oc, node.get_simd())
313
331
 
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
16
16
  from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
17
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
17
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
18
18
 
19
19
 
20
20
  class CoreConfig:
@@ -23,14 +23,14 @@ class CoreConfig:
23
23
  """
24
24
  def __init__(self,
25
25
  quantization_config: QuantizationConfig = QuantizationConfig(),
26
- mixed_precision_config: MixedPrecisionQuantizationConfigV2 = None,
26
+ mixed_precision_config: MixedPrecisionQuantizationConfig = None,
27
27
  debug_config: DebugConfig = DebugConfig()
28
28
  ):
29
29
  """
30
30
 
31
31
  Args:
32
32
  quantization_config (QuantizationConfig): Config for quantization.
33
- mixed_precision_config (MixedPrecisionQuantizationConfigV2): Config for mixed precision quantization (optional, default=None).
33
+ mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization (optional, default=None).
34
34
  debug_config (DebugConfig): Config for debugging and editing the network quantization process.
35
35
  """
36
36
  self.quantization_config = quantization_config
@@ -17,7 +17,6 @@ from collections.abc import Callable
17
17
  from functools import partial
18
18
 
19
19
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
- from model_compression_toolkit.core.common.quantization.quantizers.kmeans_quantizer import kmeans_quantizer
21
20
  from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
22
21
  from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
23
22
  symmetric_quantizer, uniform_quantizer
@@ -40,8 +39,6 @@ def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod)
40
39
  quantizer_fn = symmetric_quantizer
41
40
  elif weights_quantization_method == QuantizationMethod.UNIFORM:
42
41
  quantizer_fn = uniform_quantizer
43
- elif weights_quantization_method == QuantizationMethod.KMEANS:
44
- quantizer_fn = kmeans_quantizer
45
42
  elif weights_quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
46
43
  quantizer_fn = lut_kmeans_quantizer
47
44
  else:
@@ -18,7 +18,6 @@ from functools import partial
18
18
 
19
19
  from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.kmeans_params import kmeans_tensor
22
21
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
23
22
  lut_kmeans_tensor, lut_kmeans_histogram
24
23
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
@@ -70,8 +69,6 @@ def get_weights_quantization_params_fn(weights_quantization_method: Quantization
70
69
  params_fn = symmetric_selection_tensor
71
70
  elif weights_quantization_method == QuantizationMethod.UNIFORM:
72
71
  params_fn = uniform_selection_tensor
73
- elif weights_quantization_method == QuantizationMethod.KMEANS:
74
- params_fn = kmeans_tensor
75
72
  elif weights_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
76
73
  params_fn = partial(lut_kmeans_tensor, is_symmetric=False)
77
74
  elif weights_quantization_method == QuantizationMethod.LUT_SYM_QUANTIZER:
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_no_clipping_selection_min_max, \
16
16
  power_of_two_selection_histogram, power_of_two_selection_tensor
17
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.kmeans_params import kmeans_tensor
18
17
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import lut_kmeans_tensor
19
18
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_no_clipping_selection_min_max
20
19
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import uniform_no_clipping_selection_min_max
@@ -57,7 +57,7 @@ else:
57
57
  Concatenate, Add
58
58
  from keras.layers.core import TFOpLambda
59
59
 
60
- from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
60
+ from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
61
61
  from model_compression_toolkit.core import common
62
62
  from model_compression_toolkit.core.common import Graph, BaseNode
63
63
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -355,7 +355,7 @@ class KerasImplementation(FrameworkImplementation):
355
355
 
356
356
  def get_sensitivity_evaluator(self,
357
357
  graph: Graph,
358
- quant_config: MixedPrecisionQuantizationConfigV2,
358
+ quant_config: MixedPrecisionQuantizationConfig,
359
359
  representative_data_gen: Callable,
360
360
  fw_info: FrameworkInfo,
361
361
  disable_activation_for_metric: bool = False,
@@ -14,8 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from typing import Callable
17
-
18
- from model_compression_toolkit.core import CoreConfig, MixedPrecisionQuantizationConfigV2
17
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig
19
18
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
20
19
  from model_compression_toolkit.logger import Logger
21
20
  from model_compression_toolkit.constants import TENSORFLOW
@@ -36,7 +35,7 @@ if FOUND_TF:
36
35
 
37
36
  def keras_kpi_data(in_model: Model,
38
37
  representative_data_gen: Callable,
39
- core_config: CoreConfig,
38
+ core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
40
39
  fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
41
40
  target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> KPI:
42
41
  """
@@ -73,9 +72,9 @@ if FOUND_TF:
73
72
 
74
73
  """
75
74
 
76
- if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
77
- Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfigV2 object."
78
- "Given quant_config is not of type MixedPrecisionQuantizationConfigV2.")
75
+ if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
76
+ Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfig object."
77
+ "Given quant_config is not of type MixedPrecisionQuantizationConfig.")
79
78
 
80
79
  fw_impl = KerasImplementation()
81
80