mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np

from model_compression_toolkit.core.common.constants import FOUND_TORCH
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
    import mark_quantizer, \
    QuantizationTarget
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
    import MULTIPLIER_N_BITS, EPS

if FOUND_TORCH:
    import torch
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizer_utils \
        import to_torch_tensor, get_working_device, lut_quantizer
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers \
        .base_lut_symmetric_inferable_quantizer import BaseLUTSymmetricInferableQuantizer


    @mark_quantizer(quantization_target=QuantizationTarget.Weights,
                    quantization_method=[QuantizationMethod.LUT_SYM_QUANTIZER],
                    quantizer_type=None)
    class WeightsLUTSymmetricInferableQuantizer(BaseLUTSymmetricInferableQuantizer):
        """
        Class for quantizing weights using a lut symmetric quantizer
        """

        def __init__(self,
                     num_bits: int,
                     cluster_centers: np.ndarray,
                     threshold: np.ndarray,
                     per_channel: bool,
                     channel_axis: int = None,
                     multiplier_n_bits: int = MULTIPLIER_N_BITS,
                     eps: float = EPS):
            """
            Initialize the quantizer with the specified parameters.

            Args:
                num_bits: number of bits to use for quantization
                cluster_centers: the cluster centers to assign the weights
                threshold: threshold for quantizing weights
                per_channel: whether to use per-channel quantization
                channel_axis: Axis of input to apply per-channel quantization on
                multiplier_n_bits: Number of bits that determines the quantization range
                eps: Small value for numerical stability in division

            Raises:
                ValueError: if per-channel quantization is requested without a channel
                    axis, or if the threshold length does not match the requested
                    quantization granularity.
            """

            super(WeightsLUTSymmetricInferableQuantizer, self).__init__(threshold=threshold,
                                                                        num_bits=num_bits,
                                                                        cluster_centers=cluster_centers,
                                                                        signed=True,
                                                                        multiplier_n_bits=multiplier_n_bits,
                                                                        eps=eps)

            # Validate with explicit exceptions rather than `assert`: asserts are
            # silently stripped when Python runs with the -O flag, so they must not
            # be relied on for input validation.
            if per_channel:
                if channel_axis is None:
                    raise ValueError('Channel axis is missing in per channel quantization')
                if len(threshold) < 1:
                    raise ValueError(f'In per-channel quantization threshold should be of length >= 1 but is '
                                     f'{len(threshold)}')
            else:
                if len(threshold) != 1:
                    raise ValueError(f'In per-tensor quantization threshold should be of length 1 but is '
                                     f'{len(threshold)}')

            self.per_channel = per_channel
            self.channel_axis = channel_axis

            # Move quantization parameters to the working device once at
            # construction time, so __call__ does not pay a transfer per call.
            self.threshold = to_torch_tensor(self.threshold).to(get_working_device())
            self.cluster_centers = to_torch_tensor(self.cluster_centers).to(get_working_device())

        def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
            """
            Quantize the given inputs using the quantizer parameters.

            Args:
                inputs: input tensor to quantize

            Returns:
                quantized tensor.
            """
            # Inference-only quantizer: gradients should not flow through it.
            # NOTE(review): this mutates the caller's tensor in place — presumably
            # intended for frozen weights; confirm against callers.
            inputs.requires_grad = False
            return lut_quantizer(inputs, cluster_centers=self.cluster_centers, signed=True,
                                 threshold=self.threshold, multiplier_n_bits=self.multiplier_n_bits, eps=self.eps)


else:
    class WeightsLUTSymmetricInferableQuantizer:  # pragma: no cover
        def __init__(self, *args, **kwargs):
            raise Exception('Installing torch is mandatory '
                            'when using WeightsLUTSymmetricInferableQuantizer. '
                            'Could not find torch package.')
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np

from model_compression_toolkit.core.common.constants import FOUND_TORCH
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
    QuantizationTarget

if FOUND_TORCH:
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_symmetric_inferable_quantizer import \
        WeightsSymmetricInferableQuantizer


    @mark_quantizer(quantization_target=QuantizationTarget.Weights,
                    quantization_method=[QuantizationMethod.POWER_OF_TWO],
                    quantizer_type=None)
    class WeightsPOTInferableQuantizer(WeightsSymmetricInferableQuantizer):
        """
        Class for quantizing weights using power-of-two quantizer
        """

        def __init__(self,
                     num_bits: int,
                     threshold: np.ndarray,
                     per_channel: bool,
                     channel_axis: int = None
                     ):
            """
            Initialize the quantizer with the specified parameters.

            Args:
                num_bits: number of bits to use for quantization
                threshold: threshold for quantizing activations
                per_channel: whether to use per-channel quantization
                channel_axis: Axis of input to apply per-channel quantization on.

            Raises:
                ValueError: if any threshold value is not a power of two.
            """
            # target of Weights quantization
            super(WeightsPOTInferableQuantizer, self).__init__(num_bits=num_bits,
                                                               threshold=threshold,
                                                               per_channel=per_channel,
                                                               channel_axis=channel_axis)

            # Compute log2 of the thresholds once (the original evaluated it twice)
            # and raise an explicit exception instead of using `assert`, which is
            # silently stripped when Python runs with the -O flag.
            log2_threshold = np.log2(threshold.flatten())
            if not np.all(np.round(log2_threshold) == log2_threshold):
                raise ValueError(f'Expected threshold to be power of 2 but is {threshold}')


else:
    class WeightsPOTInferableQuantizer:  # pragma: no cover
        def __init__(self, *args, **kwargs):
            raise Exception('Installing torch is mandatory '
                            'when using WeightsPOTInferableQuantizer. '
                            'Could not find torch package.')
@@ -0,0 +1,104 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
21
+ QuantizationTarget
22
+
23
if FOUND_TORCH:
    import torch
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizer_utils import to_torch_tensor, \
        get_working_device
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_symmetric_inferable_quantizer import \
        BaseSymmetricInferableQuantizer


    @mark_quantizer(quantization_target=QuantizationTarget.Weights,
                    quantization_method=[QuantizationMethod.SYMMETRIC],
                    quantizer_type=None)
    class WeightsSymmetricInferableQuantizer(BaseSymmetricInferableQuantizer):
        """
        Class for quantizing weights using a symmetric quantizer
        """

        def __init__(self,
                     num_bits: int,
                     threshold: np.ndarray,
                     per_channel: bool,
                     channel_axis: int = None
                     ):
            """
            Initialize the quantizer with the specified parameters.

            Args:
                num_bits: number of bits to use for quantization
                threshold: threshold for quantizing weights
                per_channel: whether to use per-channel quantization
                channel_axis: Axis of input to apply per-channel quantization on.

            Raises:
                ValueError: if channel_axis is missing for per-channel
                    quantization, or if the threshold length does not match
                    the per-channel/per-tensor setting.
            """
            # Weights are always quantized as signed values.
            super(WeightsSymmetricInferableQuantizer, self).__init__(threshold=threshold,
                                                                     num_bits=num_bits,
                                                                     signed=True)

            # Validate with explicit raises rather than `assert`, which is
            # silently stripped when Python runs with -O.
            if per_channel:
                if channel_axis is None:
                    raise ValueError('Channel axis is missing in per channel quantization')
                if len(threshold) < 1:
                    raise ValueError(f'In per-channel quantization threshold should be of length >= 1 but is '
                                     f'{len(threshold)}')
            else:
                if len(threshold) != 1:
                    raise ValueError(f'In per-tensor quantization threshold should be of length 1 but is {len(threshold)}')

            self.per_channel = per_channel
            self.channel_axis = channel_axis

            # self.scales is set by the base class from the thresholds; move it
            # to the working device as a torch tensor. Symmetric quantization
            # has zero offset, so zero points are all zeros.
            self.scales = to_torch_tensor(self.scales).to(get_working_device())
            self.zero_points = torch.zeros(len(threshold), dtype=torch.int32).to(get_working_device())

        def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
            """
            Quantize the given inputs using the quantizer parameters.

            Args:
                inputs: input tensor to quantize

            Returns:
                quantized tensor.
            """
            # NOTE(review): this mutates the caller's tensor in place by
            # disabling gradient tracking on it — presumably intentional for
            # inference-only weights; confirm callers do not rely on grads.
            inputs.requires_grad = False
            if self.per_channel:
                return torch.fake_quantize_per_channel_affine(inputs,
                                                              self.scales,
                                                              self.zero_points,
                                                              axis=self.channel_axis,
                                                              quant_min=self.min_quantized_domain,
                                                              quant_max=self.max_quantized_domain)
            return torch.fake_quantize_per_tensor_affine(inputs,
                                                         self.scales,
                                                         self.zero_points,
                                                         quant_min=self.min_quantized_domain,
                                                         quant_max=self.max_quantized_domain)


else:
    class WeightsSymmetricInferableQuantizer:  # pragma: no cover
        def __init__(self, *args, **kwargs):
            raise Exception('Installing torch is mandatory '
                            'when using WeightsSymmetricInferableQuantizer. '
                            'Could not find torch package.')
@@ -0,0 +1,109 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.core.common.logger import Logger
20
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget, \
22
+ mark_quantizer
23
+
24
if FOUND_TORCH:
    import torch
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizer_utils import get_working_device, \
        fix_range_to_include_zero, to_torch_tensor
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_uniform_inferable_quantizer import \
        BaseUniformInferableQuantizer


    @mark_quantizer(quantization_target=QuantizationTarget.Weights,
                    quantization_method=[QuantizationMethod.UNIFORM],
                    quantizer_type=None)
    class WeightsUniformInferableQuantizer(BaseUniformInferableQuantizer):
        """
        Class for quantizing weights using a uniform quantizer
        """

        def __init__(self,
                     num_bits: int,
                     min_range: np.ndarray,
                     max_range: np.ndarray,
                     per_channel: bool,
                     channel_axis: int = None
                     ):
            """
            Initialize the quantizer with the specified parameters.

            Args:
                num_bits: number of bits to use for quantization
                min_range: min quantization range for quantizing weights
                max_range: max quantization range for quantizing weights
                per_channel: whether to use per-channel quantization
                channel_axis: Axis of input to apply per-channel quantization on.
            """
            super(WeightsUniformInferableQuantizer, self).__init__(num_bits=num_bits,
                                                                   min_range=min_range,
                                                                   max_range=max_range)

            # Align min/max numpy arrays so they are torch Tensors on the working device
            min_range = to_torch_tensor(min_range).to(get_working_device())
            max_range = to_torch_tensor(max_range).to(get_working_device())

            self.per_channel = per_channel
            self.channel_axis = channel_axis

            # Adjust the range so the quantization grid contains zero exactly.
            min_range, max_range = fix_range_to_include_zero(min_range,
                                                             max_range,
                                                             num_bits)
            # Compute the step size of quantized values.
            self.scales = (max_range - min_range) / (2 ** num_bits - 1)
            # zp has to be positive, and min_range <= 0, so we multiply by -1.
            self.zero_points = -(min_range / self.scales).int()

            self.scales = self.scales.to(get_working_device())
            self.zero_points = self.zero_points.to(get_working_device())

        def __call__(self,
                     inputs: torch.Tensor) -> torch.Tensor:
            """
            Weight fake quantizer
            Args:
                inputs: weights to quantize.

            Returns:
                quantized weights
            """
            # NOTE(review): mutates the caller's tensor in place by disabling
            # gradient tracking — presumably intentional for inference-only
            # weights; confirm callers do not rely on grads.
            inputs.requires_grad = False
            if self.per_channel:
                return torch.fake_quantize_per_channel_affine(inputs,
                                                              self.scales.flatten(),
                                                              self.zero_points.flatten(),
                                                              axis=self.channel_axis,
                                                              quant_min=self.min_quantized_domain,
                                                              quant_max=self.max_quantized_domain)
            return torch.fake_quantize_per_tensor_affine(inputs,
                                                         self.scales,
                                                         self.zero_points,
                                                         quant_min=self.min_quantized_domain,
                                                         quant_max=self.max_quantized_domain)


else:
    class WeightsUniformInferableQuantizer:  # pragma: no cover
        def __init__(self, *args, **kwargs):
            # Raise directly (instead of delegating to Logger.error) so a
            # missing torch install always fails loudly, consistent with the
            # symmetric and power-of-two weight inferable quantizers.
            raise Exception('Installing torch is mandatory '
                            'when using WeightsUniformInferableQuantizer. '
                            'Could not find torch package.')
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,200 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from abc import abstractmethod
16
+ from enum import Enum
17
+ from typing import Union, List, Any
18
+ from inspect import signature
19
+
20
+ from model_compression_toolkit.core import common
21
+ from model_compression_toolkit.core.common import Logger
22
+
23
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer, \
24
+ QuantizationTarget
25
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
26
+ TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
27
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import QUANTIZATION_METHOD, \
28
+ QUANTIZATION_TARGET
29
+
30
+
31
# Keys of the per-variable entries stored in BaseTrainableQuantizer.quantizer_parameters:
# VAR maps to the trainable variable itself, GROUP to its VariableGroup tag.
VAR = 'var'
GROUP = 'group'

class VariableGroup(Enum):
    """
    An enum for choosing trainable variable group
    0. WEIGHTS - the quantized weight variables themselves.
    1. QPARAMS - the trainable quantization parameters (e.g. thresholds/ranges).
    """
    WEIGHTS = 0
    QPARAMS = 1
+
43
+
44
class BaseTrainableQuantizer(BaseInferableQuantizer):
    """
    Base class for trainable quantizers: validates the quantizer's config and
    @mark_quantizer metadata at construction time, and manages the registry of
    trainable quantizer variables (quantizer_parameters).
    """

    def __init__(self,
                 quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig]):
        """
        This class is a base quantizer which validates the provided quantization config and defines an abstract function which any quantizer needs to implement.

        Args:
            quantization_config: quantizer config class contains all the information about the quantizer configuration.
        """

        # verify the quantizer class that inherits this class only has a config argument and key-word arguments
        for i, (k, v) in enumerate(self.get_sig().parameters.items()):
            if i == 0:
                if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
                    common.Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig")  # pragma: no cover
            elif v.default is v.empty:
                common.Logger.error(f"Parameter {k} doesn't have a default value")  # pragma: no cover

        super(BaseTrainableQuantizer, self).__init__()
        self.quantization_config = quantization_config

        # Inherited class should be decorated with @mark_quantizer decorator, and define the following static properties
        static_quantization_method = getattr(self, QUANTIZATION_METHOD, None)
        static_quantization_target = getattr(self, QUANTIZATION_TARGET, None)

        if static_quantization_method is None or static_quantization_target is None:
            Logger.error("A quantizer class that inherit from BaseTrainableQuantizer is not defined appropriately."
                         "Either it misses the @mark_quantizer decorator or the decorator is not used correctly.")

        # The declared quantization target (weights vs. activation) must agree
        # with the kind of config object that was passed in, and the config's
        # quantization method must be one the quantizer was marked to support.
        if static_quantization_target == QuantizationTarget.Weights:
            self.validate_weights()
            if self.quantization_config.weights_quantization_method not in static_quantization_method:
                common.Logger.error(
                    f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.weights_quantization_method}')
        elif static_quantization_target == QuantizationTarget.Activation:
            self.validate_activation()
            if self.quantization_config.activation_quantization_method not in static_quantization_method:
                common.Logger.error(
                    f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.activation_quantization_method}')
        else:
            common.Logger.error(
                f'Unknown Quantization Part:{static_quantization_target}')  # pragma: no cover

        # Registry of trainable variables: name -> {VAR: variable, GROUP: VariableGroup}.
        self.quantizer_parameters = {}

    @classmethod
    def get_sig(cls):
        """Return the constructor signature of the concrete quantizer class."""
        return signature(cls)

    def initialize_quantization(self,
                                tensor_shape,
                                name: str,
                                layer):
        """
        This initializes the quantizer parameters given the parameter name and shape.

        Args:
            tensor_shape: tensor shape
            name: tensor name
            layer: layer to quantized

        Returns: None

        """
        # Bug fix: `raise NotImplemented` raises a TypeError at runtime
        # (NotImplemented is a sentinel value, not an exception class).
        raise NotImplementedError  # pragma: no cover

    def __call__(self,
                 input2quantize,
                 training: bool):
        """
        Quantize a tensor.

        Args:
            input2quantize: Input tensor to quantize.
            training: Whether the graph is in training mode.

        Returns:
            The quantized tensor.
        """
        raise NotImplementedError  # pragma: no cover

    def activation_quantization(self) -> bool:
        """

        Returns: A boolean stating is this activation quantizer

        """
        return isinstance(self.quantization_config, TrainableQuantizerActivationConfig)

    def weights_quantization(self) -> bool:
        """

        Returns: A boolean stating is this weights quantizer

        """
        return isinstance(self.quantization_config, TrainableQuantizerWeightsConfig)

    def validate_weights(self) -> None:
        """
        This function validates the quantization config compared with its parameters.


        """
        if self.activation_quantization() or not self.weights_quantization():
            common.Logger.error(f'Expect weight quantization got activation')

    def validate_activation(self) -> None:
        """
        This function validates the quantization config compared with its parameters.

        """
        if not self.activation_quantization() or self.weights_quantization():
            common.Logger.error(f'Expect activation quantization got weight')

    def convert2inferable(self) -> BaseInferableQuantizer:
        """
        Convert quantizer to inferable quantizer.

        Returns:
            BaseInferableQuantizer object.
        """
        raise NotImplementedError  # pragma: no cover

    def add_quantizer_variable(self, name: str, variable: Any, group: VariableGroup = VariableGroup.WEIGHTS):
        """
        Add a quantizer variable to quantizer_parameters dictionary
        """
        self.quantizer_parameters.update({name: {VAR: variable, GROUP: group}})

    def get_quantizer_variable(self, name: str) -> Any:
        """
        Get a quantizer variable by name

        Args:
            name: variable name

        Returns:
            trainable variable
        """
        if name in self.quantizer_parameters:
            return self.quantizer_parameters[name][VAR]
        else:
            common.Logger.error(f'Variable {name} is not exist in quantizers parameters!')  # pragma: no cover


    @abstractmethod
    def get_trainable_variables(self, group: VariableGroup) -> List[Any]:
        """
        Get trainable parameters with specific group from quantizer

        Args:
            group: Enum of variable group

        Returns:
            List of trainable variables
        """
        raise NotImplementedError  # pragma: no cover