mct-nightly 1.11.0.20240303.post423__py3-none-any.whl → 1.11.0.20240305.post352__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/METADATA +5 -5
- {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/RECORD +32 -30
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
- model_compression_toolkit/core/common/quantization/core_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
- model_compression_toolkit/core/pytorch/constants.py +3 -0
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
- model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +4 -4
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -3
- model_compression_toolkit/pruning/__init__.py +1 -0
- model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +4 -7
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +3 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
- {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240303.post423.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: mct-nightly
|
|
3
|
-
Version: 1.11.0.
|
|
3
|
+
Version: 1.11.0.20240305.post352
|
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
|
5
5
|
Home-page: UNKNOWN
|
|
6
6
|
License: UNKNOWN
|
|
@@ -173,11 +173,11 @@ This pruning technique is designed to compress models for specific hardware arch
|
|
|
173
173
|
taking into account the target platform's Single Instruction, Multiple Data (SIMD) capabilities.
|
|
174
174
|
By pruning groups of channels (SIMD groups), our approach not only reduces model size
|
|
175
175
|
and complexity, but ensures that better utilization of channels is in line with the SIMD architecture
|
|
176
|
-
for a target KPI of weights memory footprint.
|
|
176
|
+
for a target KPI of weights memory footprint.
|
|
177
|
+
[Keras API](https://sony.github.io/model_optimization/docs/api/experimental_api_docs/methods/keras_pruning_experimental.html)
|
|
178
|
+
[Pytorch API](https://sony.github.io/model_optimization/docs/api/experimental_api_docs/methods/pytorch_pruning_experimental.html)
|
|
177
179
|
|
|
178
180
|
|
|
179
|
-
<u>_Note: Currently, only Keras models pruning is supported._</u>
|
|
180
|
-
|
|
181
181
|
#### Results
|
|
182
182
|
|
|
183
183
|
Results for applying pruning to reduce the parameters of the following models by 50%:
|
|
@@ -185,7 +185,7 @@ Results for applying pruning to reduce the parameters of the following models by
|
|
|
185
185
|
| Model | Dense Model Accuracy | Pruned Model Accuracy |
|
|
186
186
|
|-----------------|----------------------|-----------------------|
|
|
187
187
|
| ResNet50 [2] | 75.1 | 72.4 |
|
|
188
|
-
| DenseNet121 [
|
|
188
|
+
| DenseNet121 [3] | 74.44 | 71.71 |
|
|
189
189
|
|
|
190
190
|
|
|
191
191
|
|
|
@@ -2,7 +2,7 @@ model_compression_toolkit/__init__.py,sha256=7NFeY28X4NSHeLTq1JcKkAnp031EMjorU-I
|
|
|
2
2
|
model_compression_toolkit/constants.py,sha256=_OW_bUeQmf08Bb4oVZ0KfUt-rcCeNOmdBv3aP7NF5fM,3631
|
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
4
4
|
model_compression_toolkit/logger.py,sha256=b9DVktZ-LymFcRxv2aL_sdiE6S2sSrFGWltx6dgEuUY,4863
|
|
5
|
-
model_compression_toolkit/core/__init__.py,sha256=
|
|
5
|
+
model_compression_toolkit/core/__init__.py,sha256=P-7OYR4TFYxVV_ZpIJBogkX8bGvXcijlF65Ez3ivjhc,1838
|
|
6
6
|
model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
|
|
7
7
|
model_compression_toolkit/core/exporter.py,sha256=U_-ea-zYHsnIt2ydameMLZ_gzDaCMI1dRa5IjA8RUuc,4233
|
|
8
8
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=3xp0WYqyeRdlBkf5R6uD2zWubg_JPttOwS7JRhKykBY,10043
|
|
@@ -11,7 +11,7 @@ model_compression_toolkit/core/runner.py,sha256=RgN9l0v7aFYu6MTuIZGAB2syr6NBqG_v
|
|
|
11
11
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
|
12
12
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
|
13
13
|
model_compression_toolkit/core/common/data_loader.py,sha256=jCoVIb4yeOWyCrCNRB1W-mgLSyqNVGEepFXrIqufVc4,4119
|
|
14
|
-
model_compression_toolkit/core/common/framework_implementation.py,sha256=
|
|
14
|
+
model_compression_toolkit/core/common/framework_implementation.py,sha256=n3T0uOfeni_P-5ut4ModjPdtRUWCWzfAIOW0LdOlVb4,20529
|
|
15
15
|
model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
|
|
16
16
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
|
17
17
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
|
@@ -62,10 +62,10 @@ model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256
|
|
|
62
62
|
model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,sha256=LLDguK7afsbN742ucLpmJr5TUfTyFpK1vbf2bpVr1v0,882
|
|
63
63
|
model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=kmyBcqGh3qYqo42gIZzouQEljTNpF9apQt6cXEVkTQ0,3871
|
|
64
64
|
model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=x0cweemRG3_7FlvAbxFK5Zi77qpoKAGqtGndY8MtgwM,2222
|
|
65
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=
|
|
66
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=
|
|
65
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=wKiG-d_qg5dgPpZRK6dvP0MCXUKIfDNpeCHQrI3JxqY,4245
|
|
66
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=mN-QeabIu_Mz1IzPeQjqgqprCTdwGm4ThYX0gZAek-E,7103
|
|
67
67
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=XrHMQAzIUwyHqrY30cwK1lP5jKI0toeWFEG4e3qKzlY,35988
|
|
68
|
-
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=
|
|
68
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=z_0hz6mX27Q-Xzg-ftvoyPrw_i9faBGslkP0u0WGxTg,28223
|
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
|
70
70
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=A0-XsePLlqeYDgAFrDLJjYa4mY3VFKRWYQ3vHpPE-c4,6326
|
|
71
71
|
model_compression_toolkit/core/common/mixed_precision/kpi_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
@@ -83,7 +83,7 @@ model_compression_toolkit/core/common/network_editors/node_filters.py,sha256=uML
|
|
|
83
83
|
model_compression_toolkit/core/common/pruning/__init__.py,sha256=DGJybkDQtKMSMFoZ-nZ3ZifA8uJ6G_D20wHhKHNlmU0,699
|
|
84
84
|
model_compression_toolkit/core/common/pruning/channels_grouping.py,sha256=4jsr1xEBNpV2c4ipi366IfHoHCJVqoRUTTOJdlRomvc,3892
|
|
85
85
|
model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py,sha256=Go3FPHEyvm_Ra0c3XmdUVEFLCGWFoddMuYXd3Trdxkw,7594
|
|
86
|
-
model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=
|
|
86
|
+
model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=RnMmgNDHekKFOj-b-ad5rhjuKUvbVawy1A31nxuCRTg,19217
|
|
87
87
|
model_compression_toolkit/core/common/pruning/prune_graph.py,sha256=ddbZLuWvlNoj5so_5NRbIuG5qDFxD9ApG2gPirbov8o,3317
|
|
88
88
|
model_compression_toolkit/core/common/pruning/pruner.py,sha256=vXxzBXQ-oAEnw6PAD1SUiNXX7Xix4JJ0LAmV04sjFz0,7313
|
|
89
89
|
model_compression_toolkit/core/common/pruning/pruning_config.py,sha256=IfF824hNttyw2i4Tuf3g8CUfelJR3eZuOLzf2aEZNAM,3442
|
|
@@ -99,19 +99,18 @@ model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=hk
|
|
|
99
99
|
model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=gmzD32xsfJH8vkkqaspS7vYa6VWayk1GJe-NfoAEugQ,5901
|
|
100
100
|
model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
|
101
101
|
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=VW0vGCPKMxbhl2cB_zHx3g9c7cNS4ctVEAvnaNq17jw,5153
|
|
102
|
-
model_compression_toolkit/core/common/quantization/core_config.py,sha256=
|
|
102
|
+
model_compression_toolkit/core/common/quantization/core_config.py,sha256=IkD4Jl9PWdPucfUMq0TtyUl5DBJvha7Dd2xSW7_7dz8,2015
|
|
103
103
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
|
|
104
104
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=4GCr4Z6pRMbxIAnq4s7YtdMSqwbRwUzTzCFfs2ahVfk,6137
|
|
105
105
|
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=Q634XzMtjqReiLni8974y13apzbZ9nref-XBGjH17-0,16761
|
|
106
106
|
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=0547euhiaEX4vgpvIwHd7-pZ3iI7QmIc6Y_qHV4Y5sY,6713
|
|
107
|
-
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=
|
|
108
|
-
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=
|
|
107
|
+
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=f2Qa2majjO-gIN3lxqsA8icKJ9FMP-sKbw3lI6XNgBg,2137
|
|
108
|
+
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=mrgVzZszWjxnjT8zm77UVLWKTOwd2thGBo6WNqAS4X8,3867
|
|
109
109
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=_OQEFAdYDTHu2Qp-qs02Z1CDxugUKG6k5eCePS1WpXY,2939
|
|
110
110
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=UK_YshvZI0-LrKeT9gFGYcMA7pma1kaR5JAfzJH3HNw,3614
|
|
111
111
|
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=qkrWJXLyDSIJhvT8tO9Nh51f4abyVR8zMFuaaMRRrRw,12304
|
|
112
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=
|
|
112
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
|
|
113
113
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=rwCedE0zggamSBY50rqh-xqZpIMrn8o96YH_jMCuPrk,16505
|
|
114
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py,sha256=qDfJbvY64KLOG6n18ddEPTFGrKHlaXzZ136TrVpgH9s,2917
|
|
115
114
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=h8Zmpq3KdcsdUUy7K1fvWOVSki0mxT8wtKZXGmgFl74,7405
|
|
116
115
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
|
|
117
116
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=W4j9IB1Grj_Ku1pLjPxb-HLcYU9LTDuf9_0JilbqU2w,8484
|
|
@@ -122,7 +121,6 @@ model_compression_toolkit/core/common/quantization/quantization_params_generatio
|
|
|
122
121
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=53OFL3IZxtH-UPakf3h_LZkaZAa3cgc2oqgMUe3Sg8o,9689
|
|
123
122
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=oiJn1twYpTaq_z5qX4d8_nnk_jouYWHq8z0WAknl5oE,7879
|
|
124
123
|
model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
125
|
-
model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py,sha256=ZS3IXGbUTW580vwVF5jgxfPVYVL3tQrpvoSqDxVu7zQ,2325
|
|
126
124
|
model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py,sha256=P0x_y18LypBxP2tV9OWizheYfILqvaMC8RwHo04sUpQ,2761
|
|
127
125
|
model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py,sha256=5JuPwb9HDHaYQj1YyNWGY7GdjJ105Yr8iEEZhzfuRW4,14190
|
|
128
126
|
model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py,sha256=FVeuK-LeuAsRFcqo5uaNHmb6oTOFs21ltghtqswl6KM,5486
|
|
@@ -152,10 +150,10 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
|
|
|
152
150
|
model_compression_toolkit/core/keras/constants.py,sha256=YhuzRqXAdkRFzLT5lRD_jtLVYcUb-d4fUm-D49z5XOg,3158
|
|
153
151
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
154
152
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_KTS5Po1Xnv6GyK9eprpDhYWf-eooS62Ys,4961
|
|
155
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
|
153
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=kvrZeq3gmHeJKZXPyz55NGwO3EseTgo6022ZCWaUPW8,28144
|
|
156
154
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
|
157
155
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
|
|
158
|
-
model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=
|
|
156
|
+
model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=nZdRhQuIXjoL3sq2HjffKHWfnafVM-j_oqk0Cc5Op3I,4889
|
|
159
157
|
model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=1kBs9URqZTfmRXAsCqvnekV5bKUL3MyqGbORewLIwu8,2457
|
|
160
158
|
model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nBgUZ-baE3M6SzCuQbcnq4iebY1jtJBvKHOM,808
|
|
161
159
|
model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=GSh1Piz5qpA7IlvHTMqUvPn7WBDa0IHEDZdd_TzY9XA,2226
|
|
@@ -192,7 +190,7 @@ model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1
|
|
|
192
190
|
model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=s4m5jR2OsBCgDY0W0rB2T3YQwZPQvhZvNfSAPc9Vggo,4922
|
|
193
191
|
model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=X_1qI-1iD6HxG6ERe48OlCelw-nn3k4NZpeV1vtfNOs,6363
|
|
194
192
|
model_compression_toolkit/core/keras/pruning/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
|
|
195
|
-
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=
|
|
193
|
+
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=KUnSnpx8ty8NmNc6HA-HiWg-FBg7rBwkbLgOM_5ZfJA,12676
|
|
196
194
|
model_compression_toolkit/core/keras/quantizer/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
197
195
|
model_compression_toolkit/core/keras/quantizer/base_quantizer.py,sha256=eMRjAUU189-AVwNGMlV0M-ZlL48ZYmILzutheUT00xU,1628
|
|
198
196
|
model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py,sha256=Oi64CD83OopPoQNAarl2MJRbCKujU2W8Wdrs9KOPNWk,6151
|
|
@@ -211,11 +209,11 @@ model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9H
|
|
|
211
209
|
model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
|
|
212
210
|
model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
213
211
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
214
|
-
model_compression_toolkit/core/pytorch/constants.py,sha256=
|
|
212
|
+
model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
|
|
215
213
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
|
216
|
-
model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=
|
|
214
|
+
model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=xQdg8vtkwhx1uzElrh0KtwKdWFr6b2Guzv913iA_VoI,4978
|
|
217
215
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
|
|
218
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
|
216
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=EEHOnsXsvWlNm4x-4ADuFHsIc8vLrZP9pMh3raegaNg,26074
|
|
219
217
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
|
|
220
218
|
model_compression_toolkit/core/pytorch/utils.py,sha256=Bxep5o-Zw3O_iLdlzidiU186r0O6MXXtkJIZm-L1PiA,2847
|
|
221
219
|
model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
|
|
@@ -254,6 +252,8 @@ model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_
|
|
|
254
252
|
model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
255
253
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=iEtLjZw4l8Oyj81QnxbQFRKVuTbQEI98P5LuU7SNGq8,4656
|
|
256
254
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=VUN9vvWQWAh281C0xgV3w4T2DkSaxFZ-xmBgF50vGdo,5961
|
|
255
|
+
model_compression_toolkit/core/pytorch/pruning/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
|
256
|
+
model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=a1-P0ZI8jO1qlYNWqas9AvRb3bwcrNluul6qPc140R8,14601
|
|
257
257
|
model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
258
258
|
model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=rox-f5wbRyxU1UHeHyaoIDXB9r9fCXm1dPN4FVwHqTc,6464
|
|
259
259
|
model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
|
|
@@ -341,7 +341,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
|
|
|
341
341
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
|
342
342
|
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=75j56X2AcNI_5hInsLvXnWZOGMZIKQY2hStVKBaA_Bc,17705
|
|
343
343
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=FIGqzJbG6GkdHenvdMu-tGTjp4j9BewdF_spmWCb4Mo,4627
|
|
344
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
|
|
344
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=STtrf7bzqVr1P1Pah3cftGRcYSF7i-tSkcjrcuMdpTw,15162
|
|
345
345
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
|
346
346
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=8NrJBftKFbMAF_jYaAbLP6GBwpCv3Ln1NKURaV75zko,4770
|
|
347
347
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
|
@@ -358,7 +358,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
|
|
|
358
358
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
|
359
359
|
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=RfkJNQbeqGZQvlmV0dZO7YJ894Gx2asLnnIHFdWNEZ0,15078
|
|
360
360
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=-0GDC2cr-XXS7cTFTnDflJivGN7VaPnzVPsxCE-vZNU,3955
|
|
361
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
|
|
361
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=Mr0AWmqaxdr9O4neBej4hB6BWxwP_okHjQ2SwUoVvrM,13318
|
|
362
362
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
|
363
363
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=Zb-P0yRyZHHBlDvUBdRwxDpdduEJyJp6OT9pfKFF5ks,4171
|
|
364
364
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
|
@@ -370,20 +370,22 @@ model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_qu
|
|
|
370
370
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
|
|
371
371
|
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
372
372
|
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=6uxq_w62jn8DDOt9T7VtA6jZ8jTAPcbTufKFOYpVUm4,8768
|
|
373
|
-
model_compression_toolkit/pruning/__init__.py,sha256=
|
|
373
|
+
model_compression_toolkit/pruning/__init__.py,sha256=lQMZS8G0pvR1LVi53nnJHNXgLNTan_MWMdwsVxhjrow,1106
|
|
374
374
|
model_compression_toolkit/pruning/keras/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
|
|
375
375
|
model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=PHKZYBHrVyR348-a6gw44NrV8Ra9iaeFJ0WbWYpzX8k,8020
|
|
376
|
+
model_compression_toolkit/pruning/pytorch/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
|
|
377
|
+
model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=hO13kbCIyoAgVW6hRoJCUpHr6ThOuqz1Vkg-cftJY5k,8906
|
|
376
378
|
model_compression_toolkit/ptq/__init__.py,sha256=50QBTXOKdj9XLjXtrvf0mhC9FlW6TOi9-pjl96RLR14,930
|
|
377
379
|
model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
|
|
378
380
|
model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
379
|
-
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
|
|
381
|
+
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=J6luWnjkyk8KvXwZnVb5sEG9fynhXIOBMbIRPXDs2g8,9805
|
|
380
382
|
model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
381
|
-
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=
|
|
383
|
+
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=gsr_tPfY9HbAxKWaHgiq3J_osXUkOoazdLuvWdyyg3A,8567
|
|
382
384
|
model_compression_toolkit/qat/__init__.py,sha256=BYKgH1NwB9fqF1TszULQ5tDfLI-GqgZV5sao-lDN9EM,1091
|
|
383
385
|
model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
|
|
384
386
|
model_compression_toolkit/qat/common/qat_config.py,sha256=kbSxFL6_u28furq5mW_75STWDmyX4clPt-seJAnX3IQ,3445
|
|
385
387
|
model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
386
|
-
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=
|
|
388
|
+
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=3hdUBm-Q8mRe_sBrpHKFguVaUQPEH95ZeiyYN_EclwI,16073
|
|
387
389
|
model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
|
|
388
390
|
model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=gPuIgQb8OafvC3SuA8jNsGoy8S8eTsDCEKuh36WDNss,2104
|
|
389
391
|
model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
|
|
@@ -395,7 +397,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
|
|
|
395
397
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
|
|
396
398
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
|
|
397
399
|
model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
398
|
-
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=
|
|
400
|
+
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=pNBkkT9Ey1CmrQUqg-NHdf4tFVjgoVldX8bqXJp9lik,12473
|
|
399
401
|
model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
|
|
400
402
|
model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=FnhuFCuQoSf78FM1z1UZgXXd3k-mKSM7i9dYOuJUmeA,2213
|
|
401
403
|
model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=GOYRDXvQSGe_iUFVmvDy5BqC952hu_-rQO06n8QCyw0,5491
|
|
@@ -470,8 +472,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
470
472
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
471
473
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
472
474
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
|
|
473
|
-
mct_nightly-1.11.0.
|
|
474
|
-
mct_nightly-1.11.0.
|
|
475
|
-
mct_nightly-1.11.0.
|
|
476
|
-
mct_nightly-1.11.0.
|
|
477
|
-
mct_nightly-1.11.0.
|
|
475
|
+
mct_nightly-1.11.0.20240305.post352.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
476
|
+
mct_nightly-1.11.0.20240305.post352.dist-info/METADATA,sha256=_XbUd27NO_Oa_pHVS9Jw32qfVPNAtdEeCOqXXLAoPhA,17377
|
|
477
|
+
mct_nightly-1.11.0.20240305.post352.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
478
|
+
mct_nightly-1.11.0.20240305.post352.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
479
|
+
mct_nightly-1.11.0.20240305.post352.dist-info/RECORD,,
|
|
@@ -22,6 +22,6 @@ from model_compression_toolkit.core.common.mixed_precision import mixed_precisio
|
|
|
22
22
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG
|
|
23
23
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
25
|
-
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
|
|
25
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
|
|
26
26
|
from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data
|
|
27
27
|
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data
|
|
@@ -18,7 +18,7 @@ from typing import Callable, Any, List, Tuple, Dict
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
|
|
20
20
|
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
|
|
21
|
-
from model_compression_toolkit.core import
|
|
21
|
+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
|
|
22
22
|
from model_compression_toolkit.core import common
|
|
23
23
|
from model_compression_toolkit.core.common import BaseNode
|
|
24
24
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
|
@@ -304,7 +304,7 @@ class FrameworkImplementation(ABC):
|
|
|
304
304
|
@abstractmethod
|
|
305
305
|
def get_sensitivity_evaluator(self,
|
|
306
306
|
graph: Graph,
|
|
307
|
-
quant_config:
|
|
307
|
+
quant_config: MixedPrecisionQuantizationConfig,
|
|
308
308
|
representative_data_gen: Callable,
|
|
309
309
|
fw_info: FrameworkInfo,
|
|
310
310
|
hessian_info_service: HessianInfoService = None,
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py
CHANGED
|
@@ -13,16 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from
|
|
17
|
-
from typing import List, Callable, Tuple
|
|
16
|
+
from typing import List, Callable
|
|
18
17
|
|
|
19
|
-
from model_compression_toolkit.logger import Logger
|
|
20
18
|
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
|
|
21
|
-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, DEFAULTCONFIG
|
|
22
19
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
|
|
23
20
|
|
|
24
21
|
|
|
25
|
-
class
|
|
22
|
+
class MixedPrecisionQuantizationConfig:
|
|
26
23
|
|
|
27
24
|
def __init__(self,
|
|
28
25
|
compute_distance_fn: Callable = None,
|
|
@@ -36,8 +33,6 @@ class MixedPrecisionQuantizationConfigV2:
|
|
|
36
33
|
metric_normalization_threshold: float = 1e10):
|
|
37
34
|
"""
|
|
38
35
|
Class with mixed precision parameters to quantize the input model.
|
|
39
|
-
Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
|
|
40
|
-
support mixed-precision model quantization.
|
|
41
36
|
|
|
42
37
|
Args:
|
|
43
38
|
compute_distance_fn (Callable): Function to compute a distance between two tensors.
|
|
@@ -70,67 +65,6 @@ class MixedPrecisionQuantizationConfigV2:
|
|
|
70
65
|
self.metric_normalization_threshold = metric_normalization_threshold
|
|
71
66
|
|
|
72
67
|
|
|
73
|
-
class MixedPrecisionQuantizationConfig(QuantizationConfig):
|
|
74
|
-
|
|
75
|
-
def __init__(self,
|
|
76
|
-
qc: QuantizationConfig = DEFAULTCONFIG,
|
|
77
|
-
compute_distance_fn: Callable = compute_mse,
|
|
78
|
-
distance_weighting_method: Callable = get_average_weights,
|
|
79
|
-
num_of_images: int = 32,
|
|
80
|
-
configuration_overwrite: List[int] = None,
|
|
81
|
-
num_interest_points_factor: float = 1.0):
|
|
82
|
-
"""
|
|
83
|
-
Class to wrap all different parameters the library quantize the input model according to.
|
|
84
|
-
Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
|
|
85
|
-
support mixed-precision model quantization.
|
|
86
|
-
|
|
87
|
-
Args:
|
|
88
|
-
qc (QuantizationConfig): QuantizationConfig object containing parameters of how the model should be quantized.
|
|
89
|
-
compute_distance_fn (Callable): Function to compute a distance between two tensors.
|
|
90
|
-
distance_weighting_method (Callable): Function to use when weighting the distances among different layers when computing the sensitivity metric.
|
|
91
|
-
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
|
|
92
|
-
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
|
|
93
|
-
num_interest_points_factor: A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
|
|
94
|
-
|
|
95
|
-
"""
|
|
96
|
-
|
|
97
|
-
super().__init__(**qc.__dict__)
|
|
98
|
-
self.compute_distance_fn = compute_distance_fn
|
|
99
|
-
self.distance_weighting_method = distance_weighting_method
|
|
100
|
-
self.num_of_images = num_of_images
|
|
101
|
-
self.configuration_overwrite = configuration_overwrite
|
|
102
|
-
|
|
103
|
-
assert 0.0 < num_interest_points_factor <= 1.0, "num_interest_points_factor should represent a percentage of " \
|
|
104
|
-
"the base set of interest points that are required to be " \
|
|
105
|
-
"used for mixed-precision metric evaluation, " \
|
|
106
|
-
"thus, it should be between 0 to 1"
|
|
107
|
-
self.num_interest_points_factor = num_interest_points_factor
|
|
108
|
-
|
|
109
|
-
def separate_configs(self) -> Tuple[QuantizationConfig, MixedPrecisionQuantizationConfigV2]:
|
|
110
|
-
"""
|
|
111
|
-
A function to separate the old MixedPrecisionQuantizationConfig into QuantizationConfig
|
|
112
|
-
and MixedPrecisionQuantizationConfigV2
|
|
113
|
-
|
|
114
|
-
Returns: QuantizationConfig, MixedPrecisionQuantizationConfigV2
|
|
115
|
-
|
|
116
|
-
"""
|
|
117
|
-
_dummy_quant_config = QuantizationConfig()
|
|
118
|
-
_dummy_mp_config_experimental = MixedPrecisionQuantizationConfigV2()
|
|
119
|
-
qc_dict = {}
|
|
120
|
-
mp_dict = {}
|
|
121
|
-
for k, v in self.__dict__.items():
|
|
122
|
-
if hasattr(_dummy_quant_config, k):
|
|
123
|
-
qc_dict.update({k: v})
|
|
124
|
-
elif hasattr(_dummy_mp_config_experimental, k):
|
|
125
|
-
mp_dict.update({k: v})
|
|
126
|
-
else:
|
|
127
|
-
Logger.error(f'Attribute "{k}" mismatch: exists in MixedPrecisionQuantizationConfig but not in '
|
|
128
|
-
f'MixedPrecisionQuantizationConfigV2') # pragma: no cover
|
|
129
|
-
|
|
130
|
-
return QuantizationConfig(**qc_dict), MixedPrecisionQuantizationConfigV2(**mp_dict)
|
|
131
|
-
|
|
132
|
-
|
|
133
68
|
# Default quantization configuration the library use.
|
|
134
|
-
DEFAULT_MIXEDPRECISION_CONFIG = MixedPrecisionQuantizationConfig(
|
|
135
|
-
|
|
136
|
-
get_average_weights)
|
|
69
|
+
DEFAULT_MIXEDPRECISION_CONFIG = MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse,
|
|
70
|
+
distance_weighting_method=get_average_weights)
|
|
@@ -18,7 +18,7 @@ from enum import Enum
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
from typing import List, Callable, Dict
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit.core import
|
|
21
|
+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
|
|
22
22
|
from model_compression_toolkit.core.common import Graph
|
|
23
23
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
|
|
@@ -48,7 +48,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
|
|
|
48
48
|
fw_info: FrameworkInfo,
|
|
49
49
|
fw_impl: FrameworkImplementation,
|
|
50
50
|
target_kpi: KPI,
|
|
51
|
-
mp_config:
|
|
51
|
+
mp_config: MixedPrecisionQuantizationConfig,
|
|
52
52
|
representative_data_gen: Callable,
|
|
53
53
|
search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING,
|
|
54
54
|
hessian_info_service: HessianInfoService=None) -> List[int]:
|
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
|
18
18
|
from typing import Callable, Any, List, Tuple
|
|
19
19
|
|
|
20
20
|
from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA
|
|
21
|
-
from model_compression_toolkit.core import FrameworkInfo,
|
|
21
|
+
from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
|
|
22
22
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
24
24
|
|
|
@@ -37,7 +37,7 @@ class SensitivityEvaluation:
|
|
|
37
37
|
|
|
38
38
|
def __init__(self,
|
|
39
39
|
graph: Graph,
|
|
40
|
-
quant_config:
|
|
40
|
+
quant_config: MixedPrecisionQuantizationConfig,
|
|
41
41
|
representative_data_gen: Callable,
|
|
42
42
|
fw_info: FrameworkInfo,
|
|
43
43
|
fw_impl: Any,
|
|
@@ -307,7 +307,25 @@ class MemoryCalculator:
|
|
|
307
307
|
total_params += w.size
|
|
308
308
|
|
|
309
309
|
# Adjust the total parameter count if padded channels are to be included.
|
|
310
|
-
|
|
310
|
+
if output_mask is not None:
|
|
311
|
+
num_oc = np.sum(output_mask)
|
|
312
|
+
else:
|
|
313
|
+
# Get the node channel axis from framework info
|
|
314
|
+
channel_axis = self.fw_info.out_channel_axis_mapping.get(node.type)
|
|
315
|
+
if channel_axis is None:
|
|
316
|
+
Logger.error("Channel axis is not defined")
|
|
317
|
+
|
|
318
|
+
# Check if node.output_shape is a list of lists.
|
|
319
|
+
# In this case make sure all the out channels are the same value
|
|
320
|
+
if all(isinstance(sublist, list) for sublist in node.output_shape):
|
|
321
|
+
compare_value = node.output_shape[0][channel_axis]
|
|
322
|
+
if all(len(sublist) > channel_axis and sublist[channel_axis] == compare_value for sublist in node.output_shape):
|
|
323
|
+
num_oc = compare_value
|
|
324
|
+
else:
|
|
325
|
+
Logger.error("Number of out channels are not the same for all outputs of the node")
|
|
326
|
+
else:
|
|
327
|
+
num_oc = node.output_shape[channel_axis]
|
|
328
|
+
|
|
311
329
|
if include_padded_channels:
|
|
312
330
|
total_params = self.get_node_nparams_with_padded_channels(node, total_params, num_oc, node.get_simd())
|
|
313
331
|
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
16
16
|
from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
|
|
17
|
-
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import
|
|
17
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class CoreConfig:
|
|
@@ -23,14 +23,14 @@ class CoreConfig:
|
|
|
23
23
|
"""
|
|
24
24
|
def __init__(self,
|
|
25
25
|
quantization_config: QuantizationConfig = QuantizationConfig(),
|
|
26
|
-
mixed_precision_config:
|
|
26
|
+
mixed_precision_config: MixedPrecisionQuantizationConfig = None,
|
|
27
27
|
debug_config: DebugConfig = DebugConfig()
|
|
28
28
|
):
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
31
|
Args:
|
|
32
32
|
quantization_config (QuantizationConfig): Config for quantization.
|
|
33
|
-
mixed_precision_config (
|
|
33
|
+
mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization (optional, default=None).
|
|
34
34
|
debug_config (DebugConfig): Config for debugging and editing the network quantization process.
|
|
35
35
|
"""
|
|
36
36
|
self.quantization_config = quantization_config
|
|
@@ -17,7 +17,6 @@ from collections.abc import Callable
|
|
|
17
17
|
from functools import partial
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
20
|
-
from model_compression_toolkit.core.common.quantization.quantizers.kmeans_quantizer import kmeans_quantizer
|
|
21
20
|
from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
|
|
22
21
|
from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
|
|
23
22
|
symmetric_quantizer, uniform_quantizer
|
|
@@ -40,8 +39,6 @@ def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod)
|
|
|
40
39
|
quantizer_fn = symmetric_quantizer
|
|
41
40
|
elif weights_quantization_method == QuantizationMethod.UNIFORM:
|
|
42
41
|
quantizer_fn = uniform_quantizer
|
|
43
|
-
elif weights_quantization_method == QuantizationMethod.KMEANS:
|
|
44
|
-
quantizer_fn = kmeans_quantizer
|
|
45
42
|
elif weights_quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
|
|
46
43
|
quantizer_fn = lut_kmeans_quantizer
|
|
47
44
|
else:
|
|
@@ -18,7 +18,6 @@ from functools import partial
|
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.logger import Logger
|
|
20
20
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
21
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.kmeans_params import kmeans_tensor
|
|
22
21
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
|
|
23
22
|
lut_kmeans_tensor, lut_kmeans_histogram
|
|
24
23
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
|
|
@@ -70,8 +69,6 @@ def get_weights_quantization_params_fn(weights_quantization_method: Quantization
|
|
|
70
69
|
params_fn = symmetric_selection_tensor
|
|
71
70
|
elif weights_quantization_method == QuantizationMethod.UNIFORM:
|
|
72
71
|
params_fn = uniform_selection_tensor
|
|
73
|
-
elif weights_quantization_method == QuantizationMethod.KMEANS:
|
|
74
|
-
params_fn = kmeans_tensor
|
|
75
72
|
elif weights_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
|
|
76
73
|
params_fn = partial(lut_kmeans_tensor, is_symmetric=False)
|
|
77
74
|
elif weights_quantization_method == QuantizationMethod.LUT_SYM_QUANTIZER:
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py
CHANGED
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_no_clipping_selection_min_max, \
|
|
16
16
|
power_of_two_selection_histogram, power_of_two_selection_tensor
|
|
17
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.kmeans_params import kmeans_tensor
|
|
18
17
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import lut_kmeans_tensor
|
|
19
18
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_no_clipping_selection_min_max
|
|
20
19
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import uniform_no_clipping_selection_min_max
|
|
@@ -57,7 +57,7 @@ else:
|
|
|
57
57
|
Concatenate, Add
|
|
58
58
|
from keras.layers.core import TFOpLambda
|
|
59
59
|
|
|
60
|
-
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig,
|
|
60
|
+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
|
|
61
61
|
from model_compression_toolkit.core import common
|
|
62
62
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
63
63
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
@@ -355,7 +355,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
355
355
|
|
|
356
356
|
def get_sensitivity_evaluator(self,
|
|
357
357
|
graph: Graph,
|
|
358
|
-
quant_config:
|
|
358
|
+
quant_config: MixedPrecisionQuantizationConfig,
|
|
359
359
|
representative_data_gen: Callable,
|
|
360
360
|
fw_info: FrameworkInfo,
|
|
361
361
|
disable_activation_for_metric: bool = False,
|
|
@@ -14,8 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
from typing import Callable
|
|
17
|
-
|
|
18
|
-
from model_compression_toolkit.core import CoreConfig, MixedPrecisionQuantizationConfigV2
|
|
17
|
+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig
|
|
19
18
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
20
19
|
from model_compression_toolkit.logger import Logger
|
|
21
20
|
from model_compression_toolkit.constants import TENSORFLOW
|
|
@@ -36,7 +35,7 @@ if FOUND_TF:
|
|
|
36
35
|
|
|
37
36
|
def keras_kpi_data(in_model: Model,
|
|
38
37
|
representative_data_gen: Callable,
|
|
39
|
-
core_config: CoreConfig,
|
|
38
|
+
core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
|
|
40
39
|
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
|
|
41
40
|
target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> KPI:
|
|
42
41
|
"""
|
|
@@ -73,9 +72,9 @@ if FOUND_TF:
|
|
|
73
72
|
|
|
74
73
|
"""
|
|
75
74
|
|
|
76
|
-
if not isinstance(core_config.mixed_precision_config,
|
|
77
|
-
Logger.error("KPI data computation can't be executed without
|
|
78
|
-
"Given quant_config is not of type
|
|
75
|
+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
76
|
+
Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfig object."
|
|
77
|
+
"Given quant_config is not of type MixedPrecisionQuantizationConfig.")
|
|
79
78
|
|
|
80
79
|
fw_impl = KerasImplementation()
|
|
81
80
|
|