mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -1,105 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import List
17
- from tensorflow.keras.layers import Layer
18
- from tensorflow.python.util.object_identity import Reference as TFReference
19
-
20
- from model_compression_toolkit import get_target_platform_capabilities
21
- from model_compression_toolkit.core import common
22
- from model_compression_toolkit.core.common import BaseNode
23
- from model_compression_toolkit.core.common.constants import TENSORFLOW
24
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
-
26
- from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
27
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
28
- from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
29
- from model_compression_toolkit.qat.keras.quantizer.quantization_dispatcher_builder import \
30
- quantization_dispatcher_builder
31
- from model_compression_toolkit import qunatizers_infrastructure as qi
32
-
33
- DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
34
-
35
-
36
- def _is_qat_applicable(node: common.BaseNode,
37
- fw_info: FrameworkInfo) -> bool:
38
- """
39
- A function for deciding if a layer should be fine-tuned during QAT
40
- Args:
41
- node (BaseNode): Node for quantization decision
42
- fw_info (FrameworkInfo): Keras quantization information
43
-
44
- Returns:
45
- A boolean whether the layer is to be wrapped with a QuantizeWrapper
46
- """
47
-
48
- return fw_info.is_kernel_op(node.type) and node.is_weights_quantization_enabled()
49
-
50
-
51
- def qat_wrapper(n: common.BaseNode, layer: Layer):
52
- """
53
- A function which takes a computational graph node and a keras layer and perform the quantization wrapping
54
- Args:
55
- n: A node of mct graph.
56
- layer: A keras layer
57
-
58
- Returns: Wrapped layer
59
-
60
- """
61
- if _is_qat_applicable(n, DEFAULT_KERAS_INFO):
62
- return qi.KerasQuantizationWrapper(layer, quantization_dispatcher_builder(n, DEFAULT_KERAS_INFO))
63
- else:
64
- return layer
65
-
66
-
67
- class QATKerasModelBuilder(KerasModelBuilder):
68
- """
69
- Builder of QAT Keras models.
70
- """
71
-
72
- def __init__(self,
73
- graph: common.Graph,
74
- append2output=None,
75
- fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
76
- return_float_outputs: bool = False):
77
- """
78
-
79
- Args:
80
- graph: Graph to build the model from.
81
- append2output: Nodes to append to model's output.
82
- fw_info: Information about the specific framework of the model that is built.
83
- return_float_outputs: Whether the model returns float tensors or not.
84
- """
85
- super().__init__(graph,
86
- append2output,
87
- fw_info,
88
- return_float_outputs,
89
- wrapper=qat_wrapper)
90
-
91
- def _quantize_node_activations(self,
92
- node: BaseNode,
93
- input_tensors: List[TFReference]) -> List[TFReference]:
94
- """
95
- Quantize node's activation given input tensors.
96
-
97
- Args:
98
- node: Node to quantize its outputs.
99
- input_tensors: Input tensors of the node.
100
-
101
- Returns:
102
- Output of the node.
103
-
104
- """
105
- return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
@@ -1,56 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from typing import Dict
18
- from model_compression_toolkit.core import common
19
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
20
- from model_compression_toolkit import qunatizers_infrastructure as qi
21
- from model_compression_toolkit.qat.keras.quantizer.ste_rounding.symmetirc_ste import STEWeightQuantizer
22
- from model_compression_toolkit.qat.keras.quantizer.ste_rounding.uniform_ste import STEUniformWeightQuantizer
23
-
24
- METHOD2QUANTIZER = {qi.QuantizationMethod.SYMMETRIC: STEWeightQuantizer,
25
- qi.QuantizationMethod.POWER_OF_TWO: STEWeightQuantizer,
26
- qi.QuantizationMethod.UNIFORM: STEUniformWeightQuantizer}
27
-
28
-
29
- def quantization_dispatcher_builder(n: common.BaseNode,
30
- fw_info: FrameworkInfo,
31
- method2quantizer: Dict[
32
- qi.QuantizationMethod, qi.BaseKerasQuantizer] = METHOD2QUANTIZER) -> qi.KerasNodeQuantizationDispatcher:
33
- """
34
- Build a NodeQuantizationDispatcher for a node according to its quantization configuration and
35
- a global NoOpQuantizeConfig object.
36
-
37
- Args:
38
- n: Node to build its QuantizeConfig.
39
- fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
40
- method2quantizer: A mapping between quantization method to quantizer.
41
-
42
- Returns:
43
- A QuantizeConfig object with the appropriate quantizers (according to the node's
44
- quantization configuration).
45
- """
46
- nqd = qi.KerasNodeQuantizationDispatcher()
47
- if n.is_weights_quantization_enabled():
48
- attributes = fw_info.get_kernel_op_attributes(n.type)
49
- for attr in attributes:
50
- qunatizer_class = method2quantizer.get(n.final_weights_quantization_cfg.weights_quantization_method)
51
- if qunatizer_class is None:
52
- common.Logger.error(
53
- f'Unknown Quantiztion method: {n.final_weights_quantization_cfg.weights_quantization_method}')
54
- nqd.add_weight_quantizer(attr, qunatizer_class(n.final_weights_quantization_cfg))
55
-
56
- return nqd
@@ -1,145 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Dict
17
-
18
- import numpy as np
19
- import tensorflow as tf
20
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
21
- from tensorflow.python.framework.tensor_shape import TensorShape
22
-
23
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
24
-
25
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
26
- from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
27
- from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
28
- from model_compression_toolkit import qunatizers_infrastructure as qi
29
- from model_compression_toolkit.core.common import constants as C
30
-
31
-
32
- class STEWeightQuantizer(qi.BaseKerasQuantizer):
33
- """
34
- Trainable constrained quantizer to quantize a layer inputs.
35
- """
36
-
37
- def __init__(self, quantization_config: NodeWeightsQuantizationConfig):
38
- """
39
- Initialize a TrainableWeightQuantizer object with parameters to use
40
- for the quantization.
41
-
42
- Args:
43
- quantization_config: node quantization config class
44
- """
45
- super().__init__(quantization_config,
46
- qi.QuantizationTarget.Weights,
47
- [qi.QuantizationMethod.POWER_OF_TWO, qi.QuantizationMethod.SYMMETRIC])
48
- self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
49
- self.threshold_values = quantization_config.weights_quantization_params[C.THRESHOLD]
50
- self.threshold_shape = np.asarray(self.threshold_values).shape
51
- self.np_threshold_values = np.reshape(np.asarray(self.threshold_values),
52
- [-1]) if self.quantization_config.weights_channels_axis else float(
53
- self.threshold_values)
54
-
55
- if self.quantization_config.weights_per_channel_threshold and self.quantization_config.weights_channels_axis not in [
56
- -1, len(self.threshold_shape) - 1]:
57
- # Tensorflow's fake_quant_with_min_max_vars_per_channel only works on last axis, so
58
- # need to move the quantization axis to the last axis
59
- self.perm_vec = list(np.arange(len(self.threshold_shape)))
60
- self.perm_vec[self.quantization_config.weights_channels_axis] = len(self.threshold_shape) - 1
61
- self.perm_vec[len(self.threshold_shape) - 1] = self.quantization_config.weights_channels_axis
62
- else:
63
- self.perm_vec = None
64
-
65
- if self.power_of_two:
66
- self.np_threshold_values = np.power(2.0,
67
- np.ceil(np.log2(np.maximum(self.np_threshold_values, C.MIN_THRESHOLD))))
68
- num_bits = self.quantization_config.weights_n_bits
69
- delta = self.np_threshold_values / np.power(2.0, num_bits - int(C.WEIGHTS_SIGNED))
70
- min_int = -int(C.WEIGHTS_SIGNED) * (2 ** (num_bits - int(C.WEIGHTS_SIGNED)))
71
- max_int = (2 ** (num_bits - int(C.WEIGHTS_SIGNED))) - 1
72
- self.min = delta * min_int
73
- self.max = delta * max_int
74
- self.quantizer_parameters = {}
75
-
76
- def initialize_quantization(self,
77
- tensor_shape: TensorShape,
78
- name: str,
79
- layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
80
- """
81
- Add min and max variables to layer.
82
- Args:
83
- tensor_shape: Tensor shape the quantizer quantize.
84
- name: Prefix of variables names.
85
- layer: Layer to add the variables to. The variables are saved
86
- in the layer's scope.
87
-
88
- Returns:
89
- Dictionary of new variables.
90
- """
91
- ptq_threshold_tensor = layer.add_weight(
92
- name + THRESHOLD_TENSOR,
93
- shape=len(self.np_threshold_values) if self.quantization_config.weights_channels_axis else (),
94
- initializer=tf.keras.initializers.Constant(1.0),
95
- trainable=False)
96
- ptq_threshold_tensor.assign(self.np_threshold_values)
97
-
98
- fq_min = layer.add_weight(
99
- name + FQ_MIN,
100
- shape=len(self.min) if self.quantization_config.weights_channels_axis else (),
101
- initializer=tf.keras.initializers.Constant(-1.0),
102
- trainable=False)
103
- fq_min.assign(self.min)
104
-
105
- fq_max = layer.add_weight(
106
- name + FQ_MAX,
107
- shape=len(self.max) if self.quantization_config.weights_channels_axis else (),
108
- initializer=tf.keras.initializers.Constant(1.0),
109
- trainable=False)
110
- fq_max.assign(self.max)
111
-
112
- # save the quantizer added parameters for later calculations
113
- self.quantizer_parameters = {THRESHOLD_TENSOR: ptq_threshold_tensor,
114
- FQ_MIN: fq_min, FQ_MAX: fq_max}
115
- return self.quantizer_parameters
116
-
117
- def __call__(self,
118
- inputs: tf.Tensor,
119
- training: bool):
120
- """
121
- Quantize a tensor.
122
- Args:
123
- inputs: Input tensor to quantize.
124
- training: Whether the graph is in training mode.
125
- weights: Dictionary of weights the quantizer can use to quantize the tensor.
126
- **kwargs: Additional variables the quantizer may receive.
127
-
128
- Returns:
129
- The quantized tensor.
130
- """
131
-
132
- _min = self.quantizer_parameters[FQ_MIN]
133
- _max = self.quantizer_parameters[FQ_MAX]
134
- if self.quantization_config.weights_channels_axis:
135
- if self.perm_vec:
136
- inputs = tf.transpose(inputs, perm=self.perm_vec)
137
- q_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(inputs, _min, _max,
138
- num_bits=self.quantization_config.weights_n_bits)
139
- if self.perm_vec:
140
- q_tensor = tf.transpose(q_tensor, perm=self.perm_vec)
141
- else:
142
- q_tensor = tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
143
- num_bits=self.quantization_config.weights_n_bits)
144
-
145
- return q_tensor
@@ -1,8 +0,0 @@
1
- from model_compression_toolkit.qunatizers_infrastructure.common.base_quantizer import QuantizationTarget, \
2
- QuantizationMethod
3
-
4
- from model_compression_toolkit.qunatizers_infrastructure.keras.quantize_wrapper import KerasQuantizationWrapper
5
- from model_compression_toolkit.qunatizers_infrastructure.keras.base_keras_quantizer import BaseKerasQuantizer
6
- from model_compression_toolkit.qunatizers_infrastructure.keras.load_model import keras_load_quantized_model
7
- from model_compression_toolkit.qunatizers_infrastructure.keras.keras_node_quantization_dispatcher import \
8
- KerasNodeQuantizationDispatcher
@@ -1,14 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
@@ -1,123 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import List
17
- from enum import Enum
18
-
19
- from model_compression_toolkit.core import common
20
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
21
- NodeActivationQuantizationConfig, BaseNodeQuantizationConfig
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
23
-
24
-
25
- class QuantizationTarget(Enum):
26
- Activation = 0
27
- Weights = 1
28
-
29
-
30
- class BaseQuantizer:
31
- def __init__(self,
32
- quantization_config: BaseNodeQuantizationConfig,
33
- quantization_target: QuantizationTarget,
34
- quantization_method: List[QuantizationMethod]):
35
- """
36
- This class is a base quantizer which validate the the provide quantization config and define abstract function which any quantizer need to implment.
37
-
38
- Args:
39
- quantization_config: node quantization config class contins all the information above a quantizer.
40
- quantization_target: A enum which decided the qunaizer tensor type activation or weights.
41
- quantization_method: A list of "QuantizationMethod" enums which represent the quantizer supported methods.
42
- """
43
- self.quantization_config = quantization_config
44
- self.quantization_target = quantization_target
45
- self.quantization_method = quantization_method
46
- if self.quantization_target == QuantizationTarget.Weights:
47
- self.validate_weights()
48
- if self.quantization_config.weights_quantization_method not in quantization_method:
49
- common.Logger.error(
50
- f'Quantization method mismatch expected:{quantization_method} and got {self.quantization_config.weights_quantization_method}')
51
- elif self.quantization_target == QuantizationTarget.Activation:
52
- self.validate_activation()
53
- if self.quantization_config.activation_quantization_method not in quantization_method:
54
- common.Logger.error(
55
- f'Quantization method mismatch expected:{quantization_method} and got {self.quantization_config.activation_quantization_method}')
56
- else:
57
- common.Logger.error(
58
- f'Unknown Quantization Part:{quantization_target}')
59
-
60
- def initialize_quantization(self,
61
- tensor_shape,
62
- name: str,
63
- layer):
64
- """
65
- This initializes the quantizer parameters given the parameter name and shape.
66
-
67
- Args:
68
- tensor_shape: tensor shape
69
- name: tensor name
70
- layer: layer to quantized
71
-
72
- Returns: None
73
-
74
- """
75
- raise NotImplemented
76
-
77
- def __call__(self,
78
- input2quantize,
79
- training: bool):
80
- """
81
- Quantize a tensor.
82
-
83
- Args:
84
- input2quantize: Input tensor to quantize.
85
- training: Whether the graph is in training mode.
86
-
87
- Returns:
88
- The quantized tensor.
89
- """
90
- raise NotImplemented
91
-
92
- def activation_quantization(self) -> bool:
93
- """
94
-
95
- Returns: A boolean stating is this activation quantizer
96
-
97
- """
98
- return isinstance(self.quantization_config, NodeActivationQuantizationConfig)
99
-
100
- def weights_quantization(self) -> bool:
101
- """
102
-
103
- Returns: A boolean stating is this weights quantizer
104
-
105
- """
106
- return isinstance(self.quantization_config, NodeWeightsQuantizationConfig)
107
-
108
- def validate_weights(self) -> None:
109
- """
110
- This function valid the quantize config compare with it parameters.
111
-
112
-
113
- """
114
- if self.activation_quantization() or not self.weights_quantization():
115
- common.Logger.error(f'Expect weight quantization got activation')
116
-
117
- def validate_activation(self) -> None:
118
- """
119
- This function valid the quantize config compare with it parameters.
120
-
121
- """
122
- if not self.activation_quantization() or self.weights_quantization():
123
- common.Logger.error(f'Expect activation quantization got weight')
@@ -1,65 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Dict, List
17
-
18
- from model_compression_toolkit.qunatizers_infrastructure.common.base_quantizer import BaseQuantizer
19
-
20
-
21
- class NodeQuantizationDispatcher:
22
- def __init__(self,
23
- weight_quantizers: Dict[str, BaseQuantizer] = None,
24
- activation_quantizers: List[BaseQuantizer] = None):
25
- """
26
- Node quantization dispatcher collects all the quantizer of a given layer.
27
-
28
- Args:
29
- weight_quantizers: A dictionary between weight name to it quantizer .
30
- activation_quantizers: A list of activation quantization one for each layer output.
31
- """
32
- self.weight_quantizers = weight_quantizers if weight_quantizers is not None else dict()
33
- self.activation_quantizers = activation_quantizers if activation_quantizers is not None else list()
34
-
35
- def add_weight_quantizer(self, param_name: str, quantizer: BaseQuantizer):
36
- """
37
- This function add a weight quantizer to existing node dispatcher
38
-
39
- Args:
40
- param_name: The name of the parameter to quantize
41
- quantizer: A quantizer.
42
-
43
- Returns: None
44
-
45
- """
46
- self.weight_quantizers.update({param_name: quantizer})
47
-
48
- @property
49
- def is_activation_quantization(self) -> bool:
50
- """
51
- This function check activation quantizer exists in dispatcher.
52
- Returns: a boolean if activation quantizer exists
53
-
54
- """
55
- return len(self.activation_quantizers) > 0
56
-
57
- @property
58
- def is_weights_quantization(self) -> bool:
59
- """
60
- This function check weights quantizer exists in dispatcher.
61
-
62
- Returns: a boolean if weights quantizer exists
63
-
64
- """
65
- return len(self.weight_quantizers) > 0
@@ -1,14 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
@@ -1,75 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import Dict, Any, List
16
-
17
- from model_compression_toolkit.core.common import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TF
19
- from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
-
22
- from model_compression_toolkit.qunatizers_infrastructure.common.base_quantizer import BaseQuantizer, QuantizationTarget
23
-
24
- if FOUND_TF:
25
- QUANTIZATION_CONFIG = 'qunatization_config'
26
- from model_compression_toolkit.qunatizers_infrastructure.keras.config_serialization import config_serialization, \
27
- config_deserialization
28
-
29
-
30
- class BaseKerasQuantizer(BaseQuantizer):
31
- def __init__(self,
32
- quantization_config: BaseNodeQuantizationConfig,
33
- quantization_target: QuantizationTarget,
34
- quantization_method: List[QuantizationMethod]):
35
- """
36
- This class is a base quantizer which validate the provide quantization config and define abstract function which any quantizer need to implment.
37
- This class add to the base quantizer get_config and from_config function to enable keras load and save model.
38
- Args:
39
- quantization_config: node quantization config class contins all the information above a quantizer.
40
- quantization_target: A enum which decided the qunaizer tensor type activation or weights.
41
- quantization_method: A list of enums which represent the quantizer supported methods.
42
- """
43
- super().__init__(quantization_config, quantization_target, quantization_method)
44
-
45
- def get_config(self) -> Dict[str, Any]:
46
- """
47
-
48
- Returns: Configuration of BaseKerasQuantizer.
49
-
50
- """
51
- return {QUANTIZATION_CONFIG: config_serialization(self.quantization_config)}
52
-
53
- @classmethod
54
- def from_config(cls, config: dict):
55
- """
56
-
57
- Args:
58
- config(dict): dictonory of BaseKerasQuantizer Configuration
59
-
60
- Returns: A BaseKerasQuantizer
61
-
62
- """
63
- config = config.copy()
64
- quantization_config = config_deserialization(config[QUANTIZATION_CONFIG])
65
- # Note that a quantizer only receive quantization config and the rest of define hardcoded inside the speficie quantizer.
66
- return cls(quantization_config=quantization_config)
67
-
68
- else:
69
- class BaseKerasQuantizer(BaseQuantizer):
70
- def __init__(self, quantization_config: BaseNodeQuantizationConfig, quantization_target: QuantizationTarget,
71
- quantization_method: List[QuantizationMethod]):
72
- super().__init__(quantization_config, quantization_target, quantization_method)
73
- Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
74
- 'when using BaseKerasQuantizer. '
75
- 'Could not find Tensorflow package.')