mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,283 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Union
17
+
18
+ import numpy as np
19
+ import tensorflow as tf
20
+ from tensorflow.python.framework.tensor_shape import TensorShape
21
+ from model_compression_toolkit.core.common.constants import SIGNED
22
+
23
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
24
+ from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
25
+ from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
26
+ from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
27
+ from model_compression_toolkit.core.common import constants as C
28
+ from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
29
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
30
+ TrainableQuantizerActivationConfig
31
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
32
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
33
+ WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, ActivationPOTInferableQuantizer, \
34
+ ActivationSymmetricInferableQuantizer
35
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
36
+
37
+
38
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                quantizer_type=TrainingMethod.STE)
class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
    """
    Trainable constrained quantizer for a layer's weights, using a
    straight-through-estimator (STE) fake-quant with fixed (non-trainable)
    symmetric / power-of-two thresholds.
    """

    def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
        """
        Initialize the quantizer with parameters to use for the quantization.

        Args:
            quantization_config: Trainable quantizer config holding the quantization
                method, threshold values, bit width and per-channel settings.
        """
        super().__init__(quantization_config)
        self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
        self.threshold_values = quantization_config.weights_quantization_params[C.THRESHOLD]
        self.threshold_shape = np.asarray(self.threshold_values).shape
        self.per_channel = self.quantization_config.weights_per_channel_threshold
        self.channel_axis = self.quantization_config.weights_channels_axis
        # FIX: the original tested the truthiness of channel_axis, which silently
        # treated axis 0 as "no per-channel axis". Compare against None explicitly.
        if self.channel_axis is not None:
            self.np_threshold_values = np.reshape(np.asarray(self.threshold_values), [-1])
        else:
            self.np_threshold_values = float(self.threshold_values)

        if self.per_channel and self.channel_axis not in [-1, len(self.threshold_shape) - 1]:
            # Tensorflow's fake_quant_with_min_max_vars_per_channel only works on the
            # last axis, so we need to move the quantization axis to the last axis.
            # The resulting permutation swaps exactly two axes, so it is its own inverse.
            self.perm_vec = list(np.arange(len(self.threshold_shape)))
            self.perm_vec[self.channel_axis] = len(self.threshold_shape) - 1
            self.perm_vec[len(self.threshold_shape) - 1] = self.channel_axis
        else:
            self.perm_vec = None

        if self.power_of_two:
            # Round thresholds up to the nearest power of two (clipped from below by MIN_THRESHOLD).
            self.np_threshold_values = np.power(2.0, np.ceil(np.log2(np.maximum(self.np_threshold_values, C.MIN_THRESHOLD))))

        self.num_bits = self.quantization_config.weights_n_bits
        # Derive the fake-quant [min, max] grid from the threshold and bit width.
        delta = self.np_threshold_values / np.power(2.0, self.num_bits - int(C.WEIGHTS_SIGNED))
        min_int = -int(C.WEIGHTS_SIGNED) * (2 ** (self.num_bits - int(C.WEIGHTS_SIGNED)))
        max_int = (2 ** (self.num_bits - int(C.WEIGHTS_SIGNED))) - 1
        self.min = delta * min_int
        self.max = delta * max_int

    def initialize_quantization(self,
                                tensor_shape: TensorShape,
                                name: str,
                                layer: qi.KerasQuantizationWrapper):
        """
        Add quantizer parameters (threshold tensor and fake-quant min/max) to the
        wrapping layer and register them with the quantizer.

        Args:
            tensor_shape: tensor shape of the quantized tensor.
            name: Tensor name.
            layer: Layer to quantize.
        """
        # FIX: explicit None-check (truthiness would misclassify channel_axis == 0).
        per_axis = self.channel_axis is not None
        ptq_threshold_tensor = layer.add_weight(
            name + THRESHOLD_TENSOR,
            shape=len(self.np_threshold_values) if per_axis else (),
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=False)
        ptq_threshold_tensor.assign(self.np_threshold_values)

        fq_min = layer.add_weight(
            name + FQ_MIN,
            shape=len(self.min) if per_axis else (),
            initializer=tf.keras.initializers.Constant(-1.0),
            trainable=False)
        fq_min.assign(self.min)

        fq_max = layer.add_weight(
            name + FQ_MAX,
            shape=len(self.max) if per_axis else (),
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=False)
        fq_max.assign(self.max)

        # Save the quantizer's added parameters for later calculations.
        self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
        self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
        self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)

    def __call__(self,
                 inputs: tf.Tensor,
                 training: bool):
        """
        Fake-quantize a weight tensor.

        Args:
            inputs: Input tensor to quantize.
            training: Whether the graph is in training mode (not used by this quantizer).

        Returns:
            The fake-quantized tensor.
        """
        _min = self.get_quantizer_variable(FQ_MIN)
        _max = self.get_quantizer_variable(FQ_MAX)
        # FIX: explicit None-check (truthiness would misclassify channel_axis == 0).
        if self.channel_axis is not None:
            if self.perm_vec:
                # Move the quantization axis to the last axis (TF per-channel requirement).
                inputs = tf.transpose(inputs, perm=self.perm_vec)
            q_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(inputs, _min, _max,
                                                                                num_bits=self.num_bits)
            if self.perm_vec:
                # perm_vec swaps two axes, so applying it again restores the original order.
                q_tensor = tf.transpose(q_tensor, perm=self.perm_vec)
        else:
            q_tensor = tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
                                                                    num_bits=self.num_bits)

        return q_tensor

    def convert2inferable(self) -> Union[WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer]:
        """
        Convert quantizer to an inferable quantizer.

        Returns:
            A WeightsPOTInferableQuantizer (for power-of-two quantization) or a
            WeightsSymmetricInferableQuantizer, configured from this quantizer's state.
        """
        # Materialize the threshold variable as numpy once, for both branches
        # (the original applied numpy ops directly to the tf variable in the POT branch only).
        threshold = self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()
        if self.power_of_two:
            # Re-snap to powers of two in case the stored value is not an exact POT.
            pot_threshold = 2 ** np.ceil(np.log2(threshold))
            return WeightsPOTInferableQuantizer(num_bits=self.num_bits,
                                                threshold=list(pot_threshold.flatten()),
                                                per_channel=self.per_channel,
                                                channel_axis=self.channel_axis,
                                                input_rank=len(self.threshold_shape))
        else:
            return WeightsSymmetricInferableQuantizer(num_bits=self.num_bits,
                                                      threshold=list(threshold.flatten()),
                                                      per_channel=self.per_channel,
                                                      channel_axis=self.channel_axis,
                                                      input_rank=len(self.threshold_shape))
170
+
171
+
172
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
173
+ quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
174
+ quantizer_type=TrainingMethod.STE)
175
+ class STEActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
176
+ """
177
+ Trainable constrained quantizer to quantize a layer outputs.
178
+ """
179
+
180
+ def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
181
+ """
182
+ Initialize a STEActivationQATQuantizer object with parameters to use
183
+ for the quantization.
184
+
185
+ Args:
186
+ quantization_config: trainable quantizer config class
187
+ """
188
+ super().__init__(quantization_config)
189
+ self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
190
+ self.threshold_values = quantization_config.activation_quantization_params[C.THRESHOLD]
191
+ self.threshold_shape = np.asarray(self.threshold_values).shape
192
+ self.np_threshold_values = float(self.threshold_values)
193
+ self.signed = quantization_config.activation_quantization_params[SIGNED]
194
+ if self.power_of_two:
195
+ self.np_threshold_values = np.power(2.0,
196
+ np.ceil(np.log2(np.maximum(self.np_threshold_values, C.MIN_THRESHOLD))))
197
+ self.num_bits = quantization_config.activation_n_bits
198
+ delta = self.np_threshold_values / np.power(2.0, self.num_bits - int(self.signed))
199
+ min_int = -int(self.signed) * (2 ** (self.num_bits - int(self.signed)))
200
+ max_int = (2 ** (self.num_bits - int(self.signed))) - 1
201
+ self.min = delta * min_int
202
+ self.max = delta * max_int
203
+
204
+ def initialize_quantization(self,
205
+ tensor_shape: TensorShape,
206
+ name: str,
207
+ layer: qi.KerasQuantizationWrapper):
208
+ """
209
+ Add quantizer parameters to the quantizer parameters dictionary
210
+
211
+ Args:
212
+ tensor_shape: tensor shape of the quantized tensor.
213
+ name: Tensor name.
214
+ layer: Layer to quantize.
215
+ """
216
+ ptq_threshold_tensor = layer.add_weight(
217
+ name + THRESHOLD_TENSOR,
218
+ shape=(),
219
+ initializer=tf.keras.initializers.Constant(1.0),
220
+ trainable=False)
221
+ ptq_threshold_tensor.assign(self.np_threshold_values)
222
+
223
+ fq_min = layer.add_weight(
224
+ name + FQ_MIN,
225
+ shape=(),
226
+ initializer=tf.keras.initializers.Constant(-1.0),
227
+ trainable=False)
228
+ fq_min.assign(self.min)
229
+
230
+ fq_max = layer.add_weight(
231
+ name + FQ_MAX,
232
+ shape=(),
233
+ initializer=tf.keras.initializers.Constant(1.0),
234
+ trainable=False)
235
+ fq_max.assign(self.max)
236
+
237
+ # save the quantizer added parameters for later calculations
238
+ self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
239
+ self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
240
+ self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
241
+
242
+
243
+ def __call__(self,
244
+ inputs: tf.Tensor,
245
+ training: bool):
246
+ """
247
+ Quantize a tensor.
248
+ Args:
249
+ inputs: Input tensor to quantize.
250
+ training: Whether the graph is in training mode.
251
+
252
+ Returns:
253
+ The quantized tensor.
254
+ """
255
+
256
+ _min = self.get_quantizer_variable(FQ_MIN)
257
+ _max = self.get_quantizer_variable(FQ_MAX)
258
+ q_tensor = tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
259
+ num_bits=self.num_bits)
260
+
261
+ return q_tensor
262
+
263
+ def convert2inferable(self) -> Union[ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer]:
264
+ """
265
+ Convert quantizer to inferable quantizer.
266
+
267
+ Returns:
268
+ BaseKerasInferableQuantizer object.
269
+ """
270
+
271
+ if self.power_of_two:
272
+ pot_threshold = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR)))
273
+ return ActivationPOTInferableQuantizer(num_bits=self.num_bits,
274
+ # In activation quantization is per-tensor only - thus we pass
275
+ # the threshold as a list with a len of 1
276
+ threshold=[pot_threshold],
277
+ signed=self.signed)
278
+ else:
279
+ return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits,
280
+ # In activation quantization is per-tensor only - thus we
281
+ # pass the threshold as a list with a len of 1
282
+ threshold=[self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()],
283
+ signed=self.signed)
@@ -12,91 +12,92 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
- from typing import Dict, Any, List
17
-
18
15
  import numpy as np
19
16
  import tensorflow as tf
20
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
21
17
  from tensorflow.python.framework.tensor_shape import TensorShape
22
18
  from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX
19
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
23
20
  from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
24
- from model_compression_toolkit import qunatizers_infrastructure as qi
25
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
26
-
27
-
28
- class STEUniformWeightQuantizer(qi.BaseKerasQuantizer):
21
+ from model_compression_toolkit.qat.keras.quantizer.quant_utils import adjust_range_to_include_zero
22
+ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
23
+ from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
24
+ from model_compression_toolkit.core.common import constants as C
25
+ from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
26
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
27
+ TrainableQuantizerActivationConfig
28
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
29
+ mark_quantizer
30
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
31
+ BaseKerasInferableQuantizer, WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
32
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
33
+
34
+
35
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
36
+ quantization_method=[QuantizationMethod.UNIFORM],
37
+ quantizer_type=TrainingMethod.STE)
38
+ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
29
39
  """
30
40
  Trainable constrained quantizer to quantize a layer inputs.
31
41
  """
32
42
 
33
- def __init__(self, quantization_config: NodeWeightsQuantizationConfig):
43
+ def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
34
44
  """
35
45
  Initialize a TrainableWeightQuantizer object with parameters to use
36
46
  for the quantization.
37
47
 
38
48
  Args:
39
- quantization_config: a quantization config class with attributes for the quantization.
49
+ quantization_config: a trainable quantizer config class with attributes for the quantization.
40
50
 
41
51
  """
42
- super().__init__(quantization_config,
43
- qi.QuantizationTarget.Weights,
44
- [qi.QuantizationMethod.UNIFORM])
52
+ super().__init__(quantization_config)
45
53
  self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
46
54
  self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
55
+ self.num_bits = self.quantization_config.weights_n_bits
56
+ self.per_channel = self.quantization_config.weights_per_channel_threshold
57
+ self.channel_axis = self.quantization_config.weights_channels_axis
47
58
  self.min_max_shape = np.asarray(self.max_values).shape
48
- self.max = np.reshape(self.max_values,
49
- [-1]) if self.quantization_config.weights_per_channel_threshold else float(
50
- self.max_values)
51
- self.min = np.reshape(self.min_values,
52
- [-1]) if self.quantization_config.weights_per_channel_threshold else float(
53
- self.min_values)
54
-
55
- if self.quantization_config.weights_per_channel_threshold and self.quantization_config.weights_channels_axis not in [
56
- -1,
57
- len(self.min_max_shape) - 1]:
59
+ self.max = np.reshape(self.max_values, [-1]) if self.per_channel else float(self.max_values)
60
+ self.min = np.reshape(self.min_values, [-1]) if self.per_channel else float(self.min_values)
61
+
62
+ if self.per_channel and self.channel_axis not in [-1, len(self.min_max_shape) - 1]:
58
63
  # Tensorflow's fake_quant_with_min_max_vars_per_channel only works on last axis, so
59
64
  # need to move the quantization axis to the last axis
60
65
  self.perm_vec = list(np.arange(len(self.min_max_shape)))
61
- self.perm_vec[self.quantization_config.weights_channels_axis] = len(self.min_max_shape) - 1
62
- self.perm_vec[len(self.min_max_shape) - 1] = self.quantization_config.weights_channels_axis
66
+ self.perm_vec[self.channel_axis] = len(self.min_max_shape) - 1
67
+ self.perm_vec[len(self.min_max_shape) - 1] = self.channel_axis
63
68
  else:
64
69
  self.perm_vec = None
65
70
 
66
- self.quantizer_parameters = {}
67
-
68
71
  def initialize_quantization(self,
69
72
  tensor_shape: TensorShape,
70
73
  name: str,
71
- layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
74
+ layer: qi.KerasQuantizationWrapper):
72
75
  """
73
- Add min and max variables to layer.
74
- Args:
75
- tensor_shape: Tensor shape the quantizer quantize.
76
- name: Prefix of variables names.
77
- layer: Layer to add the variables to. The variables are saved
78
- in the layer's scope.
76
+ Add quantizer parameters to the quantizer parameters dictionary
79
77
 
80
- Returns:
81
- Dictionary of new variables.
78
+ Args:
79
+ tensor_shape: tensor shape of the quantized tensor.
80
+ name: Tensor name.
81
+ layer: Layer to quantize.
82
82
  """
83
83
  fq_min = layer.add_weight(
84
84
  name + FQ_MIN,
85
- shape=len(self.min) if self.quantization_config.weights_per_channel_threshold else (),
85
+ shape=len(self.min) if self.per_channel else (),
86
86
  initializer=tf.keras.initializers.Constant(-1.0),
87
87
  trainable=False)
88
88
  fq_min.assign(self.min)
89
89
 
90
90
  fq_max = layer.add_weight(
91
91
  name + FQ_MAX,
92
- shape=len(self.max) if self.quantization_config.weights_per_channel_threshold else (),
92
+ shape=len(self.max) if self.per_channel else (),
93
93
  initializer=tf.keras.initializers.Constant(1.0),
94
94
  trainable=False)
95
95
  fq_max.assign(self.max)
96
96
 
97
97
  # save the quantizer added parameters for later calculations
98
- self.quantizer_parameters = {FQ_MIN: fq_min, FQ_MAX: fq_max}
99
- return self.quantizer_parameters
98
+ self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
99
+ self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
100
+
100
101
 
101
102
  def __call__(self, inputs: tf.Tensor,
102
103
  training: bool):
@@ -110,17 +111,128 @@ class STEUniformWeightQuantizer(qi.BaseKerasQuantizer):
110
111
  The quantized tensor.
111
112
  """
112
113
 
113
- _min = self.quantizer_parameters[FQ_MIN]
114
- _max = self.quantizer_parameters[FQ_MAX]
115
- if self.quantization_config.weights_per_channel_threshold:
114
+ _min = self.get_quantizer_variable(FQ_MIN)
115
+ _max = self.get_quantizer_variable(FQ_MAX)
116
+ _min, _max = adjust_range_to_include_zero(_min, _max, self.num_bits)
117
+
118
+ if self.per_channel:
116
119
  if self.perm_vec:
117
120
  inputs = tf.transpose(inputs, perm=self.perm_vec)
121
+
118
122
  q_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(inputs, _min, _max,
119
- num_bits=self.quantization_config.weights_n_bits)
123
+ num_bits=self.num_bits)
120
124
  if self.perm_vec:
121
125
  q_tensor = tf.transpose(q_tensor, perm=self.perm_vec)
122
126
  else:
123
127
  q_tensor = tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
124
- num_bits=self.quantization_config.weights_n_bits)
128
+ num_bits=self.num_bits)
125
129
 
126
130
  return q_tensor
131
+
132
+ def convert2inferable(self) -> BaseKerasInferableQuantizer:
133
+ """
134
+ Convert quantizer to inferable quantizer.
135
+
136
+ Returns:
137
+ BaseKerasInferableQuantizer object.
138
+ """
139
+ min_range, max_range = fix_range_to_include_zero(self.get_quantizer_variable(FQ_MIN).numpy(),
140
+ self.get_quantizer_variable(FQ_MAX).numpy(),
141
+ self.num_bits)
142
+ return WeightsUniformInferableQuantizer(num_bits=self.num_bits,
143
+ min_range=list(min_range.flatten()),
144
+ max_range=list(max_range.flatten()),
145
+ per_channel=self.per_channel,
146
+ channel_axis=self.channel_axis,
147
+ input_rank=len(self.min_max_shape))
148
+
149
+
150
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
151
+ quantization_method=[QuantizationMethod.UNIFORM],
152
+ quantizer_type=TrainingMethod.STE)
153
+ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
154
+ """
155
+ Trainable constrained quantizer to quantize a layer outputs.
156
+ """
157
+
158
+ def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
159
+ """
160
+ Initialize a STEUniformActivationQATQuantizer object with parameters to use
161
+ for the quantization.
162
+
163
+ Args:
164
+ quantization_config: trainable quantizer config class
165
+ """
166
+ super().__init__(quantization_config)
167
+
168
+ self.num_bits = quantization_config.activation_n_bits
169
+ self.min_range = quantization_config.activation_quantization_params[C.RANGE_MIN]
170
+ self.max_range = quantization_config.activation_quantization_params[C.RANGE_MAX]
171
+
172
+ def initialize_quantization(self,
173
+ tensor_shape: TensorShape,
174
+ name: str,
175
+ layer: qi.KerasQuantizationWrapper):
176
+ """
177
+ Add quantizer parameters to the quantizer parameters dictionary
178
+
179
+ Args:
180
+ tensor_shape: tensor shape of the quantized tensor.
181
+ name: Tensor name.
182
+ layer: Layer to quantize.
183
+ """
184
+ fq_min = layer.add_weight(
185
+ name + FQ_MIN,
186
+ shape=(),
187
+ initializer=tf.keras.initializers.Constant(-1.0),
188
+ trainable=False)
189
+ fq_min.assign(self.min_range)
190
+
191
+ fq_max = layer.add_weight(
192
+ name + FQ_MAX,
193
+ shape=(),
194
+ initializer=tf.keras.initializers.Constant(1.0),
195
+ trainable=False)
196
+ fq_max.assign(self.max_range)
197
+
198
+ # save the quantizer added parameters for later calculations
199
+ self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
200
+ self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
201
+
202
+
203
+ def __call__(self,
204
+ inputs: tf.Tensor,
205
+ training: bool):
206
+ """
207
+ Quantize a tensor.
208
+ Args:
209
+ inputs: Input tensor to quantize.
210
+ training: Whether the graph is in training mode.
211
+
212
+ Returns:
213
+ The quantized tensor.
214
+ """
215
+
216
+ _min = self.get_quantizer_variable(FQ_MIN)
217
+ _max = self.get_quantizer_variable(FQ_MAX)
218
+ _min, _max = adjust_range_to_include_zero(_min, _max, self.num_bits)
219
+ q_tensor = tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
220
+ num_bits=self.num_bits)
221
+
222
+ return q_tensor
223
+
224
+ def convert2inferable(self) -> BaseKerasInferableQuantizer:
225
+ """
226
+ Convert quantizer to inferable quantizer.
227
+
228
+ Returns:
229
+ BaseKerasInferableQuantizer object.
230
+ """
231
+ min_range, max_range = fix_range_to_include_zero(self.get_quantizer_variable(FQ_MIN).numpy(),
232
+ self.get_quantizer_variable(FQ_MAX).numpy(),
233
+ self.num_bits)
234
+ return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
235
+ # In activation quantization is per-tensor only - thus we pass
236
+ # the min/max as lists with a len of 1
237
+ min_range=[min_range],
238
+ max_range=[max_range])