mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py
@@ -0,0 +1,180 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import copy
+from typing import Callable
+
+import keras.models
+import numpy as np
+import tensorflow as tf
+from keras import Sequential
+from keras.layers import Dense, Conv2D, Reshape
+from keras.models import clone_model
+
+from model_compression_toolkit import quantizers_infrastructure as qi
+from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
+    constants as keras_inferable_constants
+
+BIAS_INITIALIZER = 'bias_initializer'
+BIAS_REGULARIZER = 'bias_regularizer'
+BIAS_CONSTRAINT = 'bias_constraint'
+ACTIVITY_REGULARIZER = 'activity_regularizer'
+KERNEL_INITIALIZER = 'kernel_initializer'
+KERNEL_REGULARIZER = 'kernel_regularizer'
+KERNEL_CONSTRAINT = 'kernel_constraint'
+KERNEL_SIZE = 'kernel_size'
+PADDING = 'padding'
+STRIDES = 'strides'
+LAYER_NAME = 'name'
+TRAINABLE = 'trainable'
+ACTIVATION = 'activation'
+USE_BIAS = 'use_bias'
+FILTERS = 'filters'
+UNITS = 'units'
+PAD_VALID = 'valid'
+KERNEL = 'kernel'
+
+CONV_KERNEL_CHANNEL_AXIS = 3
+CONV_INPUT_CHANNELS_DIM = 4
+
+class INT8TFLiteExporter(FakelyQuantKerasExporter):
+    """
+    Exporter for INT8 TFLite models.
+    The exporter expects to receive an exportable model (where each layer's full quantization parameters
+    can be retrieved), and convert it into a quantized model where weights and activations are represented
+    as integer data type.
+    """
+
+    def __init__(self,
+                 model: keras.models.Model,
+                 is_layer_exportable_fn: Callable,
+                 save_model_path: str):
+        """
+
+        Args:
+            model: Model to export.
+            is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
+            save_model_path: Path to save the exported model.
+        """
+        super().__init__(model,
+                         is_layer_exportable_fn,
+                         save_model_path)
+
+        self.exported_model = None
+
+    def _get_pointwise_layer_to_replace_dense(self, wrapped_layer: qi.KerasQuantizationWrapper) -> keras.layers.Layer:
+        # First we create a pointwise configuration based on the Dense layer's configuration
+        dense_cfg = wrapped_layer.layer.get_config()
+
+        # List of pw attributes that should be taken from the dense layer as they are.
+        pw_attr_list = [LAYER_NAME, ACTIVATION, USE_BIAS, BIAS_CONSTRAINT,
+                        BIAS_INITIALIZER, BIAS_REGULARIZER, TRAINABLE, ACTIVITY_REGULARIZER,
+                        KERNEL_INITIALIZER, KERNEL_REGULARIZER, KERNEL_CONSTRAINT]
+
+        pw_cfg = {attr: dense_cfg[attr] for attr in pw_attr_list}
+
+        # Set the remaining attributes, which are not taken as they are
+        pw_cfg.update({KERNEL_SIZE: (1, 1),
+                       STRIDES: (1, 1),
+                       PADDING: PAD_VALID,
+                       FILTERS: dense_cfg[UNITS]})
+
+        # Create the point-wise layer
+        pw_layer = Conv2D(**pw_cfg)
+        pw_layer.build(wrapped_layer.layer.input_shape)
+
+        # Create and set the point-wise weights to assign
+        dense_kernel = wrapped_layer.layer.kernel
+        pw_weights = []
+        pw_kernel = np.reshape(wrapped_layer.get_weights()[0],
+                               (1, 1, dense_kernel.get_shape()[0], dense_cfg[UNITS]))
+
+        pw_weights.append(pw_kernel)
+        if wrapped_layer.layer.use_bias:
+            pw_bias = wrapped_layer.get_weights()[2]
+            pw_weights.append(pw_bias)
+
+        pw_layer.set_weights(pw_weights)
+
+        # Now that we have the point-wise layer to replace the dense layer,
+        # we need to wrap it using qi.KerasQuantizationWrapper with new
+        # relevant quantizers.
+        # Create new kernel quantizer
+        pw_kernel_quantizer_cfg = wrapped_layer.weights_quantizers[KERNEL].get_config()
+
+        # In Conv2D the kernel channel axis is 3 and not 1 as in Dense
+        pw_kernel_quantizer_cfg[keras_inferable_constants.CHANNEL_AXIS] = CONV_KERNEL_CHANNEL_AXIS
+
+        # The unquantized weight of a conv layer has 4 dimensions (unlike dense, which varies)
+        pw_kernel_quantizer_cfg[keras_inferable_constants.INPUT_RANK] = CONV_INPUT_CHANNELS_DIM
+
+        assert isinstance(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD], np.ndarray), f'Expected to find threshold which is a numpy array, but found: {type(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])}'
+        pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD] = list(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])
+
+        # Now that we have the point-wise quantizer config we can instantiate the quantizer
+        quantizer_class = type(wrapped_layer.weights_quantizers[KERNEL])
+        pw_quantizer = quantizer_class(**pw_kernel_quantizer_cfg)
+        pw_weights_quantizers = copy.deepcopy(wrapped_layer.weights_quantizers)
+        pw_weights_quantizers[KERNEL] = pw_quantizer
+
+        # Wrap pw with the new quantizers (the activation is not affected, thus we take the Dense quantizers)
+        wrapped_pw = qi.KerasQuantizationWrapper(pw_layer,
+                                                 pw_weights_quantizers,
+                                                 wrapped_layer.activation_quantizers)
+
+        # Compute the shape that the input to the new layer should be reshaped into.
+        # Example: a Dense kernel with shape (3, 20) expects an input with
+        # dimensions (BATCH_SIZE, x0, x1, ..., xn, 3).
+        # A conv layer expects a 4-rank input. Thus, the input is reshaped to (BATCH_SIZE, 1, x0*x1*...*xn, 3)
+        dim = wrapped_layer.layer.input_shape[1:-1]
+        target_shape = (1, int(np.prod(dim))) + (dense_kernel.get_shape()[0],)
+
+        return Sequential([
+            Reshape(target_shape=target_shape),
+            wrapped_pw,
+            Reshape(wrapped_layer.layer.output_shape[1:])
+        ])
+
+    def export(self) -> None:
+        """
+        Export a fully quantized model to its int8 tflite model.
+        """
+
+        def _substitute_model(wrapped_layer: qi.KerasQuantizationWrapper) -> keras.layers.Layer:
+            assert self.is_layer_exportable_fn(
+                wrapped_layer), f'Layer {wrapped_layer.get_config()} did not pass validation'
+
+            # In order to support dense quantization using per-channel quantization (which is
+            # unsupported in TFLITE int models) we substitute each dense layer with its equivalent
+            # point-wise convolution.
+            if isinstance(wrapped_layer.layer, Dense):
+                return self._get_pointwise_layer_to_replace_dense(wrapped_layer)
+
+            return wrapped_layer
+
+        # Transform the model to a new model that can be converted to int8 models.
+        # For example: replace dense layers with point-wise layers (to support per-channel quantization)
+        self.transformed_model = clone_model(self.model,
+                                             clone_function=_substitute_model)
+
+        # Convert model to int8 representation
+        converter = tf.lite.TFLiteConverter.from_keras_model(self.transformed_model)
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        self.exported_model = converter.convert()
+
+        Logger.info(f'Exporting INT8 tflite model to: {self.save_model_path}')
+        with open(self.save_model_path, 'wb') as f:
+            f.write(self.exported_model)
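
Why the Dense-to-pointwise substitution above is safe: a Dense layer with kernel shape (C_in, units) computes the same function as a 1x1 Conv2D whose kernel is that Dense kernel reshaped to (1, 1, C_in, units), applied to the input reshaped to rank 4. A minimal standalone sketch (illustrative shapes and names, not MCT code) that verifies the equivalence:

    import numpy as np
    from keras import Sequential
    from keras.layers import Dense, Conv2D, Reshape

    x = np.random.rand(2, 5, 7, 3).astype('float32')  # (batch, x0, x1, C_in)
    dense = Dense(units=20)
    y_dense = dense(x)  # builds a kernel of shape (3, 20)

    # The same kernel as a 1x1 point-wise convolution on a rank-4 reshaped input
    pw = Conv2D(filters=20, kernel_size=(1, 1), padding='valid')
    pw.build((None, 1, 5 * 7, 3))
    pw.set_weights([dense.kernel.numpy().reshape(1, 1, 3, 20), dense.bias.numpy()])

    equiv = Sequential([Reshape((1, 5 * 7, 3)), pw, Reshape((5, 7, 20))])
    print(np.allclose(y_dense.numpy(), equiv(x).numpy(), atol=1e-5))  # True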
model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py
@@ -15,41 +15,59 @@
 from enum import Enum
 from typing import Callable
 
-import keras
-
 from model_compression_toolkit.core.common import Logger
-from model_compression_toolkit.exporter.model_exporter.tflite.fakely_quant_tflite_exporter import \
-    FakelyQuantTFLiteExporter
+from model_compression_toolkit.core.common.constants import FOUND_TF
 
 
 class TFLiteExportMode(Enum):
     FAKELY_QUANT = 0
+    INT8 = 1
+
+if FOUND_TF:
+    import keras
+    from model_compression_toolkit.exporter.model_exporter.tflite.fakely_quant_tflite_exporter import FakelyQuantTFLiteExporter
+    from model_compression_toolkit.exporter.model_exporter.tflite.int8_tflite_exporter import INT8TFLiteExporter
+    from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
 
+    def tflite_export_model(model: keras.models.Model,
+                            save_model_path: str,
+                            mode: TFLiteExportMode = TFLiteExportMode.FAKELY_QUANT,
+                            is_layer_exportable_fn: Callable = is_keras_layer_exportable
+                            ):
+        """
+        Export a Keras quantized model to a tflite model.
+        The model will be saved to the path in save_model_path.
+        Mode can be used for different exported files. Currently, tflite_export_model
+        supports TFLiteExportMode.FAKELY_QUANT (where weights and activations are
+        float fakely-quantized values) and TFLiteExportMode.INT8 (where weights
+        and activations are represented using 8-bit integers).
 
-def tflite_export_model(model: keras.models.Model,
-                        is_layer_exportable_fn: Callable,
-                        mode: TFLiteExportMode = TFLiteExportMode.FAKELY_QUANT,
-                        save_model_path: str = None):
-    """
-    Prepare and return fully quantized model for export. Save exported model to
-    a path if passed.
+        Args:
+            model: Model to export.
+            is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
+            mode: Mode to export the model according to.
+            save_model_path: Path to save the model.
 
-    Args:
-        model: Model to export.
-        is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
-        mode: Mode to export the model according to.
-        save_model_path: Path to save the model.
+        """
 
-    """
+        if mode == TFLiteExportMode.FAKELY_QUANT:
+            exporter = FakelyQuantTFLiteExporter(model,
+                                                 is_layer_exportable_fn,
+                                                 save_model_path)
+        elif mode == TFLiteExportMode.INT8:
+            exporter = INT8TFLiteExporter(model,
+                                          is_layer_exportable_fn,
+                                          save_model_path)
 
-    if mode == TFLiteExportMode.FAKELY_QUANT:
-        exporter = FakelyQuantTFLiteExporter(model,
-                                             is_layer_exportable_fn,
-                                             save_model_path)
+        else:
+            Logger.critical(
+                f'Unsupported mode was used {mode.name} to export TFLite model.'
+                f' Please see API for supported modes.')  # pragma: no cover
 
-    else:
-        Logger.critical(
-            f'Unsupported mode was used {mode.name} to export TFLite model.'
-            f' Please see API for supported modes.')
+        exporter.export()
 
-    exporter.export()
+else:
+    def tflite_export_model(*args, **kwargs):
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using tflite_export_model. '
+                     'Could not find some or all of TensorFlow packages.')  # pragma: no cover
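
With is_layer_exportable_fn now defaulting to is_keras_layer_exportable, the facade can be called with just a model and a path. A hedged usage sketch (exportable_model would come from get_exportable_keras_model; the output path is illustrative):

    from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import \
        tflite_export_model, TFLiteExportMode

    # exportable_model: a Keras model whose layers are wrapped with KerasQuantizationWrapper
    tflite_export_model(model=exportable_model,
                        save_model_path='/tmp/quantized_model.tflite',
                        mode=TFLiteExportMode.INT8)  # or TFLiteExportMode.FAKELY_QUANT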
model_compression_toolkit/exporter/model_wrapper/__init__.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.core.common.constants import FOUND_TF
+from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
+from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model
 
-if FOUND_TF:
-    from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
-    from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model
+from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
+from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
@@ -12,157 +12,54 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from typing import Tuple
 
-import tensorflow as tf
-import tensorflow_model_optimization.quantization.keras.graph_transformations.model_transformer as mt
-from keras.layers import TFOpLambda
-from keras.models import Model
-from tensorflow.python.util.object_identity import Reference as TFReference
-from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
-    NoOpQuantizeConfig
-from typing import List, Tuple, Dict, Any
-
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
 
+from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import BaseNode, Graph, Logger
+from model_compression_toolkit.core.common import Graph, Logger
+from model_compression_toolkit.core.common.constants import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder, \
-    is_layer_fake_quant, get_node_name_from_layer
-from model_compression_toolkit.core.keras.quantizer.input_layer_quantize_transform import InputLayerWrapperTransform
-
-from model_compression_toolkit.exporter.model_wrapper.keras.builder.quantize_config_to_node import \
-    get_quantization_config
-from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.activation_quantize_config import \
-    ActivationQuantizeConfig
-from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_activation_quantize_config \
-    import \
-    WeightsActivationQuantizeConfig
-from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_quantize_config import \
-    WeightsQuantizeConfig
-from model_compression_toolkit.exporter.model_wrapper.keras.extended_quantize_wrapper import ExtendedQuantizeWrapper
-from model_compression_toolkit.exporter.model_wrapper.keras.quantizers.fq_quantizer import FakeQuantQuantizer
-from model_compression_toolkit.exporter.model_wrapper.keras.quantizers.weights_uniform_quantizer import \
-    WeightsUniformQuantizer
-
-
-def get_exportable_keras_model(graph: Graph) -> tf.keras.models.Model:
-    """
-    Convert graph to an exportable Keras model (model with all quantization parameters).
-    An exportable model can then be exported using model_exporter, to retrieve the
-    final exported model.
-
-    Args:
-        graph: Graph to convert to an exportable Keras model.
 
-    Returns:
-        Exportable Keras model.
-    """
+if FOUND_TF:
+    import tensorflow as tf
+    from tensorflow.keras.layers import Layer
+    from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
 
-    return FullyQuantizedKerasModelBuilder(graph=graph).build_model()
-
-
-class FullyQuantizedKerasModelBuilder(KerasModelBuilder):
-    """
-    Builder of exportable Keras models (fully quantized).
-    """
-
-    def __init__(self,
-                 graph: common.Graph):
+    def _get_wrapper(node: common.BaseNode,
+                     layer: Layer) -> qi.KerasQuantizationWrapper:
         """
-
+        A function which takes a computational graph node and a keras layer and performs the quantization wrapping.
         Args:
-            graph: Graph to build the model from.
-        """
+            node: A node of mct graph.
+            layer: A keras layer
 
-        super().__init__(graph)
+        Returns: Wrapped layer with weights quantizers and activation quantizers
 
-    def _quantize_node_activations(self,
-                                   node: BaseNode,
-                                   input_tensors: List[TFReference]) -> List[TFReference]:
         """
-        Quantize node's activation given input tensors.
+        weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
 
-        Args:
-            node: Node to quantize its outputs.
-            input_tensors: Input tensors of the node.
-
-        Returns:
-            Output of the node.
-
-        """
-        return input_tensors
 
-    def build_model(self) -> Tuple[Model, UserInformation]:
+    def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
         """
-        Build a Keras mixed-precision model and return it.
-        Returns: Mixed-precision Keras model.
-
-        """
-        model, user_info = super().build_model()
-
-        def _wrap_layer_with_quantize_config(layer):
-
-            node = self.oh.layer_to_node_dict.get(layer)
+        Convert graph to an exportable Keras model (model with all quantization parameters).
+        An exportable model can then be exported using model_exporter, to retrieve the
+        final exported model.
 
-            if node is not None:
-                # In case of layers that are in reused groups, output_shape does not exist.
-                layer_output_shape = layer.output_shape if (node.reuse_group is None) else None
-                # For now, we do not support reused TFOpLambda layers.
-                if isinstance(layer, TFOpLambda) and layer_output_shape is None:
-                    Logger.error(
-                        f'Output shape must be inferred to use ExtendedQuantizeWrapper, but it seems that TFOpLambda '
-                        f'layer {layer.name} has no output shape. If it is a reused layer, MCT does not support '
-                        f'reused TFOpLambda layers for now.')
-                return ExtendedQuantizeWrapper(layer, get_quantization_config(node), layer_output_shape)
-
-            elif is_layer_fake_quant(layer):
-                return layer
-
-            else:
-                raise Exception(
-                    f'Mismatch between keras model and graph cant find node named: '
-                    f'{get_node_name_from_layer(layer)}')
-
-        # clone each layer in the model and apply _wrap_layer_with_quantize_config to the layer.
-        model = tf.keras.models.clone_model(model,
-                                            input_tensors=None,
-                                            clone_function=_wrap_layer_with_quantize_config)
-
-
-        # We use a model transformer to wrap the input layer with QuantizeWrapper.
-        # A model transformer allows to modify a layer in an existing model, by applying the given list of
-        # transformers on the model (in this case,
-        # we only apply single transformer - InputLayerQuantizeTransform)
-        model_inputs = self.graph.get_inputs()
-
-        input_transformer = mt.ModelTransformer(model, [InputLayerWrapperTransform(inp,
-                                                                                   get_quantization_config(inp),
-                                                                                   self.get_custom_objects(),
-                                                                                   ExtendedQuantizeWrapper)
-                                                        for inp in model_inputs])
-
-        model = input_transformer.transform()[0]
-
-        return model, user_info
-
-    @staticmethod
-    def get_custom_objects() -> Dict[str, Any]:
-        """
-
-        Returns: Dictionary of custom objects needed to load this model builder's output.
+        Args:
+            graph: Graph to convert to an exportable Keras model.
 
+        Returns:
+            Exportable Keras model and user information.
         """
-        return {ExtendedQuantizeWrapper.__name__: ExtendedQuantizeWrapper,
-                QuantizeWrapper.__name__: QuantizeWrapper,
-                WeightsActivationQuantizeConfig.__name__: WeightsActivationQuantizeConfig,
-                ActivationQuantizeConfig.__name__: ActivationQuantizeConfig,
-                WeightsQuantizeConfig.__name__: WeightsQuantizeConfig,
-                WeightsUniformQuantizer.__name__: WeightsUniformQuantizer,
-                NoOpQuantizeConfig.__name__: NoOpQuantizeConfig,
-                FakeQuantQuantizer.__name__: FakeQuantQuantizer}
-
-
-
-
-
+        exportable_model, user_info = KerasModelBuilder(graph=graph,
+                                                        wrapper=_get_wrapper).build_model()
+        exportable_model.trainable = False
+        return exportable_model, user_info
+else:
+    def get_exportable_keras_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using get_exportable_keras_model. '
+                     'Could not find Tensorflow package.')
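
The if FOUND_TF / else stub pattern above (define the real API when the framework is installed, otherwise a same-named stub that reports the missing dependency) recurs throughout this release. A standalone illustration of the pattern with the check written out explicitly; in MCT, FOUND_TF itself is defined in core/common/constants.py, and the stubs log an error rather than raise:

    from importlib import util

    FOUND_TF = util.find_spec('tensorflow') is not None

    if FOUND_TF:
        def get_exportable_keras_model(graph):
            # Real implementation; may import tensorflow freely.
            ...
    else:
        def get_exportable_keras_model(*args, **kwargs):
            # Same name and signature, so callers get a clear dependency message
            # instead of an ImportError at module-import time.
            raise ImportError('tensorflow is required to use get_exportable_keras_model')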
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py
@@ -0,0 +1,143 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Dict, Any
+
+from model_compression_toolkit.core.common import BaseNode, Logger
+from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
+from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts
+
+def get_inferable_quantizer_kwargs(node: BaseNode,
+                                   quantization_target: QuantizationTarget) -> Dict[str, Any]:
+    """
+    Get the quantization parameters for an inferable quantizer.
+    Args:
+        node: The node for which the quantizer is being created.
+        quantization_target: The target of the quantization (weights or activations).
+    Returns:
+        The quantization parameters as a dictionary.
+    """
+
+    if quantization_target == QuantizationTarget.Weights:
+        # Get the weights quantization configuration for the node
+        node_w_qc = node.final_weights_quantization_cfg
+        quantization_method = node_w_qc.weights_quantization_method
+
+        # Return the appropriate quantization parameters based on the quantization method
+        if quantization_method in [QuantizationMethod.POWER_OF_TWO,
+                                   QuantizationMethod.SYMMETRIC]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
+
+        elif quantization_method in [QuantizationMethod.UNIFORM]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.MIN_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
+                    qi_keras_consts.MAX_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
+
+        elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
+
+        else:
+            Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
+
+    elif quantization_target == QuantizationTarget.Activation:
+        # Get the activation quantization configuration for the node
+        node_qc = node.final_activation_quantization_cfg
+        quantization_method = node_qc.activation_quantization_method
+
+        # Return the appropriate quantization parameters based on the quantization method
+        if quantization_method in [QuantizationMethod.POWER_OF_TWO,
+                                   QuantizationMethod.SYMMETRIC]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    # Activation quantization is per-tensor only - thus we hold the threshold as a list of length 1
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]],
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED]}
+
+        elif quantization_method in [QuantizationMethod.UNIFORM]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    # Activation quantization is per-tensor only - thus we hold the min/max as a list of length 1
+                    qi_keras_consts.MIN_RANGE: [node_qc.activation_quantization_params[RANGE_MIN]],
+                    qi_keras_consts.MAX_RANGE: [node_qc.activation_quantization_params[RANGE_MAX]]}
+
+        elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED],
+                    qi_keras_consts.CLUSTER_CENTERS: node_qc.activation_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]]
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    }
+        else:
+            Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
+    else:
+        Logger.critical(f'{quantization_target} is not supported')  # pragma: no cover
+
+
+def get_weights_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuantizer:
+    """
+    Get weights quantizer for a node.
+    Args:
+        node: Node to create a weight quantizer for.
+    Returns:
+        Quantizer for the node's weights.
+    """
+    if node.final_weights_quantization_cfg is None:
+        Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration')  # pragma:
+        # no cover
+    node_w_qc = node.final_weights_quantization_cfg
+    weights_quantization_method = node_w_qc.weights_quantization_method
+
+    quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
+                                                      weights_quantization_method,
+                                                      BaseKerasInferableQuantizer)
+    kwargs = get_inferable_quantizer_kwargs(node, QuantizationTarget.Weights)
+
+    return quantier_for_node(**kwargs)
+
+
+def get_activations_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuantizer:
+    """
+    Get activation quantizer for a node.
+    Args:
+        node: Node to create an activation quantizer for.
+    Returns:
+        Quantizer for the node's activations.
+    """
+    if node.final_activation_quantization_cfg is None:
+        Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration')  #
+        # pragma: no cover
+    node_act_qc = node.final_activation_quantization_cfg
+    activation_quantization_method = node_act_qc.activation_quantization_method
+
+    quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
+                                                      activation_quantization_method,
+                                                      BaseKerasInferableQuantizer)
+    kwargs = get_inferable_quantizer_kwargs(node, QuantizationTarget.Activation)
+
+    return quantier_for_node(**kwargs)
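
The lookup-then-instantiate flow above (get_inferable_quantizer_class followed by quantizer_class(**kwargs)) can be exercised by hand. A sketch for a symmetric, per-channel Conv2D kernel quantizer with made-up threshold values; the kwarg names are passed through the same constants module the builder uses, so no literal argument names are assumed:

    from model_compression_toolkit.core.common.target_platform import QuantizationMethod
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer

    # Look up the quantizer class registered for symmetric weights quantization
    quantizer_class = get_inferable_quantizer_class(QuantizationTarget.Weights,
                                                    QuantizationMethod.SYMMETRIC,
                                                    BaseKerasInferableQuantizer)
    kernel_quantizer = quantizer_class(**{qi_keras_consts.NUM_BITS: 8,
                                          qi_keras_consts.THRESHOLD: [2.0, 4.0],  # one per output channel
                                          qi_keras_consts.PER_CHANNEL: True,
                                          qi_keras_consts.CHANNEL_AXIS: 3,  # Conv2D kernel channel axis
                                          qi_keras_consts.INPUT_RANK: 4})  # Conv2D kernels are rank 4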
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py
@@ -0,0 +1,46 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Dict, List, Tuple
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
+    get_weights_quantizer_for_node, get_activations_quantizer_for_node
+
+
+def get_quantization_quantizers(node: BaseNode) -> Tuple[Dict, List]:
+    """
+    Create quantizers to wrap a layer for its corresponding node.
+
+    Args:
+        node: Node to create quantizers for.
+
+    Returns:
+        weight_quantizers: A dictionary mapping each weight's name to its quantizer.
+        activation_quantizers: A list of activation quantizers, one for each layer output.
+    """
+    weight_quantizers = {}
+    activation_quantizers = []
+
+    if node.is_weights_quantization_enabled():
+        weight_attrs = DEFAULT_KERAS_INFO.get_kernel_op_attributes(node.type)
+        weight_quantizer = get_weights_quantizer_for_node(node)
+        for attr in weight_attrs:
+            weight_quantizers[attr] = weight_quantizer
+
+    if node.is_activation_quantization_enabled():
+        num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1
+        activation_quantizers = [get_activations_quantizer_for_node(node)] * num_of_outputs
+
+    return weight_quantizers, activation_quantizers
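
The dictionary/list structure returned by get_quantization_quantizers maps directly onto the wrapper call shown in fully_quantized_model_builder.py above. Continuing the previous sketch (a Conv2D has a single 'kernel' weight attribute; the empty list means no activation quantization in this illustration):

    from keras.layers import Conv2D
    from model_compression_toolkit import quantizers_infrastructure as qi

    # Mirrors _get_wrapper: weight quantizers keyed by attribute name,
    # activation quantizers as a list with one entry per layer output.
    wrapped_layer = qi.KerasQuantizationWrapper(Conv2D(filters=2, kernel_size=3),
                                                {'kernel': kernel_quantizer},
                                                [])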