mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241):
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -1,263 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import Dict, Any, List
16
-
17
- import numpy as np
18
- import tensorflow as tf
19
-
20
- from model_compression_toolkit import GumbelConfig
21
- from model_compression_toolkit.core.keras.quantizer.base_quantizer import BaseTrainableQuantizer
22
- from model_compression_toolkit.core.common.defaultdict import DefaultDict
23
- from model_compression_toolkit.core import common
24
- from model_compression_toolkit.gptq.keras.quantizer import kernel_functions
25
- from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.gumbel_softmax import sample_gumbel
26
- from model_compression_toolkit.gptq.common import gptq_constants
27
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
28
- from tensorflow.python.framework.tensor_shape import TensorShape
29
- from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
30
-
31
- P_INIT = 0.01
32
- GR_SHIFT_BASE = 2
33
-
34
-
35
- def init_aux_var(ceil_indicator: np.ndarray, w_shape: List[int], m: int, p: float = P_INIT) -> np.ndarray:
36
- """
37
- This function generate a random pi matrix for Gumbel Rounding such that the search start at the rounding point.
38
- Args:
39
- ceil_indicator: An array of indicator if the value should be ceil or floor.
40
- w_shape(List[int]): A list of integers that represent the shape of the weights tensor to be quantization.
41
- p(float): A floating point number that represent the probability of non round options of pi matrix.
42
- m(int): An integer that define the number of shift.
43
-
44
- Returns: A numpy array of pi tensor
45
-
46
- """
47
- if m < 2:
48
- common.logger.Logger.error("m must be larger than two")
49
- if m % 2 != 0:
50
- common.logger.Logger.error("m must be module two")
51
- m_hat = m // 2 - 1
52
- shift = -np.log(-np.log(1 - p))
53
- n = np.random.randn(*[m, *w_shape]) * np.sqrt(np.power(np.pi, 2) / 6)
54
- n = n.reshape([m, -1]).T
55
- ceil_indicator = ceil_indicator.flatten()
56
- n[np.arange(ceil_indicator.size), ceil_indicator + m_hat] += shift
57
- n = n.T.reshape(*[m, *w_shape])
58
- return n
59
-
60
-
61
- def _init_shift_var(m: int) -> List[int]:
62
- """
63
- This function generate an list of 2*m+1 from -m to m
64
- Args:
65
- m: An integer value the represent m
66
-
67
- Returns: A list of size m
68
-
69
- """
70
- m_hat = m // 2
71
- aux_index_shift = [-m_hat + 1 + i for i in range(m)]
72
- return aux_index_shift
73
-
74
-
75
- class GumbelRoundingBase(BaseTrainableQuantizer):
76
- def __init__(self,
77
- num_bits: int,
78
- per_axis: bool,
79
- signed: bool,
80
- symmetric: bool,
81
- power_of_two: bool,
82
- quantization_parameter_learning: bool,
83
- quantization_axis: int,
84
- gumbel_config: GumbelConfig,
85
- max_lsbs_change_map: dict = DefaultDict({}, lambda: 1),
86
- max_iteration: int = 10000):
87
- """
88
- A base class for GumRounding
89
-
90
- Args:
91
- num_bits: Number of bits to use for the quantization.
92
- per_axis: Whether to quantize per-channel or per-tensor.
93
- signed: Signedness to use for the quantization range.
94
- symmetric: Whether to quantize is symmetric.
95
- power_of_two: Whether to quantize is power-of-two.
96
- quantization_parameter_learning: A bool flag state if the quantizer parameter are trainable
97
- quantization_axis: Axis of tensor to use for the quantization.
98
- gumbel_config: A class with the gumbel rounding configurations.
99
- max_lsbs_change_map: a mapping between number of bits to max lsb change.
100
- max_iteration: The number of iteration of gptq.
101
- """
102
- self.num_bits = num_bits
103
- self.per_axis = per_axis
104
- self.signed = signed
105
- self.quantization_axis = quantization_axis
106
- self.max_iteration = max_iteration
107
- self.power_of_two = power_of_two
108
- self.symmetric = symmetric
109
- self.quantization_parameter_learning = quantization_parameter_learning
110
- self.temperature_learning = gumbel_config.temperature_learning
111
- self.quantizer_parameters = {}
112
- self.gumbel_config = gumbel_config
113
-
114
- self.max_lsbs_change_map = max_lsbs_change_map
115
- self.max_lsbs_change = max_lsbs_change_map.get(num_bits)
116
- self.m = GR_SHIFT_BASE * self.max_lsbs_change + GR_SHIFT_BASE
117
-
118
- self.n_cycles = gumbel_config.n_cycles
119
- self.minimal_temp = gumbel_config.minimal_temp
120
- self.maximal_temp = gumbel_config.maximal_temp
121
- self.cycle_iterations = max(1, int(self.max_iteration / self.n_cycles))
122
- self.tau = None
123
- self.g_t = None
124
- self.p_t = None
125
- scale = self.cycle_iterations / (-2 * np.log(0.001))
126
-
127
- self.gumbel_scale = gumbel_config.gumbel_scale
128
- self.gumbel_scale_per_bitwidth = gumbel_config.gumbel_scale_per_bitwidth
129
-
130
- def tau_function(i):
131
- """
132
- A function the generate the gumbel temperature.
133
- Args:
134
- i: An int the represent the current iteration number
135
-
136
- Returns: A temperature value.
137
-
138
- """
139
- if i < (self.cycle_iterations - 1):
140
- index = ((i + 1) % self.cycle_iterations) / scale
141
- else:
142
- index = (i % self.cycle_iterations) / scale
143
-
144
- x = tf.exp(-index)
145
- return self.minimal_temp + (self.maximal_temp - self.minimal_temp) * x
146
-
147
- self.tau_function = tau_function
148
- self.w_shape = None
149
- self.update_gumbel_param = True
150
-
151
- def enable_update(self):
152
- self.update_gumbel_param = True
153
-
154
- def disable_update(self):
155
- self.update_gumbel_param = False
156
-
157
def build(self, tensor_shape: TensorShape,
          name: str,
          layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
    """
    Create the quantizer's variables on the wrapped layer.

    Args:
        tensor_shape: Shape of the tensor this quantizer quantizes.
        name: Prefix for the created variable names.
        layer: Layer to add the variables to; the variables live in the
            layer's scope.

    Returns:
        Dictionary mapping variable-name constants to the new variables.
    """
    # Cache the kernel shape; derived quantizers rely on self.w_shape.
    kernel_shape = kernel_functions.get_kernel(layer.weights).shape
    self.w_shape = kernel_shape

    # Scalar, non-trainable iteration counter.
    iteration_counter = layer.add_weight(
        name + gptq_constants.GPTQ_ITER,
        shape=(),
        initializer=tf.keras.initializers.Constant(0.0),
        trainable=False)

    # Trainable temperature, one value per kernel entry (leading axis of 1).
    temperature_var = layer.add_weight(
        name + gptq_constants.TEMP,
        shape=[1, *self.w_shape],
        initializer=tf.keras.initializers.Constant(self.maximal_temp),
        trainable=True)

    # Fixed shift vector of length self.m, filled via _init_shift_var.
    shift_var = layer.add_weight(name + gptq_constants.AUXSHIFT,
                                 shape=self.m,
                                 initializer=tf.keras.initializers.Constant(0.0),
                                 trainable=False)
    shift_var.assign(_init_shift_var(self.m))

    self.quantizer_parameters = {gptq_constants.GPTQ_ITER: iteration_counter,
                                 gptq_constants.AUXSHIFT: shift_var,
                                 gptq_constants.TEMP: temperature_var}
    return self.quantizer_parameters
196
-
197
def get_aux_variable(self) -> tf.Tensor:
    """Return the auxiliary rounding variable created at build time."""
    params = self.quantizer_parameters
    return params[gptq_constants.AUXVAR]
199
-
200
def get_trainable_parameters(self) -> List[tf.Tensor]:
    """
    Collect the quantizer variables that should be trained during GPTQ retraining.

    Returns:
        A list of trainable tensors.
    """
    trainable = []
    for variable in self.quantizer_parameters.values():
        if variable.trainable:
            trainable.append(variable)
    return trainable
209
-
210
def __eq__(self, other: Any) -> bool:
    """
    Compare this quantizer with another object for equality.

    Args:
        other: Object to compare against.

    Returns:
        True when `other` is a GumbelRoundingBase with the same bit-width,
        per-axis setting and symmetry; False otherwise.
    """
    if not isinstance(other, GumbelRoundingBase):
        return False
    same_bits = self.num_bits == other.num_bits
    same_axis = self.per_axis == other.per_axis
    same_symmetry = self.symmetric == other.symmetric
    return same_bits and same_axis and same_symmetry
225
-
226
def __ne__(self, other: Any) -> bool:
    """
    Check inequality with another object.

    Args:
        other: Object to compare against.

    Returns:
        Whether the two objects differ (negation of __eq__).
    """
    equal = self.__eq__(other)
    return not equal
236
-
237
def get_config(self) -> Dict[str, Any]:
    """
    Return the configuration of this TrainableQuantizer.

    Returns:
        Dict with the bit-width, per-axis flag, symmetry and
        power-of-two constraint of the quantizer.
    """
    return dict(num_bits=self.num_bits,
                per_axis=self.per_axis,
                symmetric=self.symmetric,
                power_of_two=self.power_of_two)
248
-
249
def update_iteration(self, training, ar_iter):
    """
    Refresh the Gumbel temperature and, in training mode, the Gumbel noise.

    Args:
        training: Whether the graph is in training mode.
        ar_iter: Variable holding the current iteration counter.
    """
    if self.temperature_learning:
        # Learned temperature: clamp between min/max with an STE clip.
        self.tau = qutils.ste_clip(self.quantizer_parameters[gptq_constants.TEMP], self.maximal_temp,
                                   self.minimal_temp)
    else:
        # Scheduled temperature: follow the annealing schedule.
        self.tau = self.tau_function(ar_iter)
    if training and self.update_gumbel_param:
        ar_iter.assign_add(1.0)
        # Resample Gumbel noise for all m candidates over the kernel shape.
        self.g_t = sample_gumbel([self.m, *self.w_shape])
258
-
259
def get_temperature_variable(self):
    """Return the trainable Gumbel temperature variable."""
    params = self.quantizer_parameters
    return params[gptq_constants.TEMP]
261
-
262
def get_gumbel_probability(self):
    """Return the most recent Gumbel-softmax probability tensor (None before the first forward pass)."""
    return self.p_t
@@ -1,75 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import tensorflow as tf
16
-
17
-
18
def sample_gumbel(shape, eps=1e-6) -> tf.Tensor:
    """
    Draw a tensor of i.i.d. Gumbel(0, 1) random variables.

    Args:
        shape: Output tensor shape.
        eps: Small constant for numerical stability of the logarithms.

    Returns:
        A tensor of i.i.d. Gumbel random variables.
    """
    uniform_noise = tf.random.uniform(shape)
    inner_log = -tf.math.log(uniform_noise + eps)
    return -tf.math.log(inner_log + eps)
30
-
31
-
32
def gumbel_softmax(in_pi: tf.Tensor, in_tau: tf.Tensor, in_gumbel: tf.Tensor, eps: float = 1e-6, axis=0,
                   gumbel_scale: float = 1.0) -> tf.Tensor:
    """
    Compute a Gumbel-softmax probability tensor.

    Args:
        in_pi: Tensor of (unnormalized) log-probabilities.
        in_tau: Temperature tensor.
        in_gumbel: Tensor of Gumbel random noise.
        eps: Small constant for numerical stability of the division.
        axis: Axis along which the softmax is applied.
        gumbel_scale: Normalization factor applied to the Gumbel noise.

    Returns:
        A Gumbel-softmax probability tensor.
    """
    log_probs = tf.nn.log_softmax(in_pi, axis=axis)
    perturbed_logits = log_probs + gumbel_scale * in_gumbel
    return tf.nn.softmax(perturbed_logits / (in_tau + eps), axis=axis)
48
-
49
-
50
def ste_gumbel(in_prob: tf.Tensor) -> tf.Tensor:
    """
    Harden Gumbel-softmax probabilities to one-hot with a straight-through estimator.

    Args:
        in_prob: A tensor of probabilities.

    Returns:
        A one-hot tensor whose gradient is that of `in_prob` (STE).
    """
    hard_selection = select_gumbel(in_prob)
    return in_prob + tf.stop_gradient(hard_selection - in_prob)
62
-
63
-
64
def select_gumbel(in_prob: tf.Tensor) -> tf.Tensor:
    """
    Convert a probability tensor into a one-hot selection of its argmax.

    Args:
        in_prob: A tensor of probabilities; candidates lie along axis 0.

    Returns:
        A one-hot tensor marking the argmax along axis 0.
    """
    winner_index = tf.argmax(in_prob, axis=0)
    one_hot = tf.one_hot(winner_index, depth=in_prob.shape[0], axis=0)
    # Adding `0 * in_prob` leaves the values unchanged — presumably to keep
    # the result connected to `in_prob` in the computation graph.
    return one_hot + 0 * in_prob
@@ -1,266 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import tensorflow as tf
16
- import numpy as np
17
-
18
- from model_compression_toolkit import GumbelConfig
19
- from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
20
- from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.base_gumbel_rounding import GumbelRoundingBase, \
21
- init_aux_var
22
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
23
- from tensorflow.python.framework.tensor_shape import TensorShape
24
- from model_compression_toolkit.core.common.defaultdict import DefaultDict
25
- from typing import Dict, Any, List
26
- from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.gumbel_softmax import gumbel_softmax, ste_gumbel
27
- from model_compression_toolkit.core.common.constants import THRESHOLD, GUMBEL_MAX_ITER, MIN_THRESHOLD
28
- from model_compression_toolkit.gptq.common import gptq_constants
29
- from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two
30
- from model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste import symmetric_quantizer
31
-
32
-
33
- def gumbel_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
34
- auxvar_tensor: tf.Variable,
35
- max_tensor: tf.Tensor,
36
- num_bits: int,
37
- signed: bool,
38
- power_of_two: bool) -> tf.Tensor:
39
- """
40
- Quantize a tensor symmetrically with maximum LSBs shift.
41
- Args:
42
- input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
43
- auxvar_tensor: Tensor that manifests the bit shift the weight due to gptq.
44
- max_tensor: Tensor with max values to compute the threshold.
45
- num_bits: Num of bits to use.
46
- signed: Signedness of the quantization range.
47
- power_of_two: Whether the threshold should be constrained or not.
48
-
49
- Returns:
50
- A quantized tensor.
51
- """
52
-
53
- if power_of_two:
54
- max_tensor = qutils.power_of_two_max(max_tensor)
55
- delta = qutils.calculate_delta(max_tensor, num_bits, signed)
56
- input_tensor = tf.stop_gradient(input_tensor)
57
- input_tensor_int = tf.floor(input_tensor / delta)
58
- tensor_q = input_tensor_int + auxvar_tensor
59
- min_int = -int(signed) * (2 ** (num_bits - int(signed)))
60
- max_int = (2 ** (num_bits - int(signed))) - 1
61
- return delta * qutils.clip(tensor_q, max_val=max_int, min_val=min_int)
62
-
63
-
64
class SymmetricGumbelRounding(GumbelRoundingBase):
    """
    Trainable constrained quantizer to quantize a layer inputs.
    """
    # Suffixes appended to the variable-name prefix when creating layer weights.
    PTQ_THRESHOLD = "_ptq_threshold"
    SCALE_PTQ = "_scale"

    def __init__(self, num_bits: int,
                 per_axis: bool,
                 signed: bool,
                 power_of_two: bool,
                 quantization_parameter_learning: bool,
                 threshold_values: np.ndarray,
                 gumbel_config: GumbelConfig,
                 quantization_axis: int = -1,
                 max_lsbs_change_map: dict = DefaultDict({}, lambda: 1),
                 max_iteration: int = GUMBEL_MAX_ITER):
        """
        Initialize a TrainableWeightQuantizer object with parameters to use
        for the quantization.

        Args:
            num_bits: Number of bits to use for the quantization.
            per_axis: Whether to quantize per-channel or per-tensor.
            signed: Signedness to use for the quantization range.
            power_of_two: Whether the threshold should be constrained or not.
            quantization_parameter_learning: Whether to learn an extra scale on top of the threshold.
            threshold_values: Threshold to use for the quantization.
            gumbel_config: A class with the gumbel rounding configurations.
            quantization_axis: Axis of tensor to use for the quantization.
            max_lsbs_change_map: A mapping between number of bits to max lsb change.
            max_iteration: The number of iterations of gptq.
        """
        # The 4th positional argument (symmetric) is hard-coded to True for this quantizer.
        super().__init__(num_bits, per_axis, signed, True, power_of_two, quantization_parameter_learning,
                         quantization_axis, gumbel_config,
                         max_lsbs_change_map,
                         max_iteration)
        # Original threshold shape, kept so get_quant_config can restore it on export.
        self.threshold_shape = np.asarray(threshold_values).shape
        # Per-axis: flatten to a 1-D array; per-tensor: collapse to a single float.
        self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_axis else float(
            threshold_values)
        # Number of independent thresholds (channels) this quantizer manages.
        self.k_threshold = len(self.threshold_values) if self.per_axis else 1

    def build(self,
              tensor_shape: TensorShape,
              name: str,
              layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
        """
        Add min and max variables to layer.
        Args:
            tensor_shape: Tensor shape the quantizer quantize.
            name: Prefix of variables names.
            layer: Layer to add the variables to. The variables are saved
            in the layer's scope.

        Returns:
            Dictionary of new variables.
        """
        # Base class creates the iteration counter, temperature and shift variables.
        super().build(tensor_shape, name, layer)

        if self.per_axis:
            input_shape = tensor_shape
            n_axis = len(input_shape)
            # Normalize a negative quantization axis to its positive index.
            quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
                self.quantization_axis
            # Threshold is broadcast over all axes except the quantization axis.
            reshape_shape = [self.k_threshold if i == quantization_axis else 1 for i in range(n_axis)]
        else:
            reshape_shape = [self.k_threshold]

        ptq_threshold_tensor = layer.add_weight(
            name + self.PTQ_THRESHOLD,
            shape=reshape_shape,
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=False)
        # NOTE(review): per-tensor mode stores threshold_values as a float,
        # which has no .reshape — presumably this path is only reached per-axis; confirm.
        ptq_threshold_tensor.assign(self.threshold_values.reshape(reshape_shape))

        # Auxiliary rounding logits: m candidates per kernel entry.
        auxvar_tensor = layer.add_weight(
            name + gptq_constants.AUXVAR,
            shape=[self.m, *self.w_shape],
            initializer=tf.keras.initializers.Constant(0.0),
            trainable=True)
        w = getattr(layer.layer, name)

        # Error of plain symmetric quantization: used to seed the aux variable.
        q_error = w - symmetric_quantizer(w,
                                          ptq_threshold_tensor,
                                          num_bits=self.num_bits,
                                          signed=True,
                                          power_of_two=self.power_of_two)

        ceil_indicator = (q_error < 0).numpy().astype("int")  # Negative error means the choose point is rounded to ceil.
        auxvar_tensor.assign(init_aux_var(ceil_indicator, self.w_shape, self.m))

        self.quantizer_parameters.update({gptq_constants.AUXVAR: auxvar_tensor,
                                          self.PTQ_THRESHOLD: ptq_threshold_tensor})

        # Optional learned per-channel scale (only meaningful without power-of-two constraint).
        if self.quantization_parameter_learning and not self.power_of_two:
            scale = layer.add_weight(
                name + self.SCALE_PTQ,
                shape=self.k_threshold,
                initializer=tf.keras.initializers.Constant(1.0),
                trainable=True)
            self.quantizer_parameters.update({self.SCALE_PTQ: scale})

        return self.quantizer_parameters

    def get_quantization_variable(self) -> List[tf.Tensor]:
        """
        Return the list of learned quantization parameters.

        Returns:
            A list with the learned scale variable, or an empty list when
            scale learning is disabled (or power-of-two is enforced).
        """
        if self.quantization_parameter_learning and not self.power_of_two:
            return [self.quantizer_parameters[self.SCALE_PTQ]]
        else:
            return []

    def __call__(self, inputs: tf.Tensor,
                 training: bool,
                 weights: Dict[str, tf.Variable],
                 **kwargs: Dict[str, Any]):
        """
        Quantize a tensor.
        Args:
            inputs: Input tensor to quantize.
            training: Whether the graph is in training mode.
            weights: Dictionary of weights the quantizer can use to quantize the tensor.
            **kwargs: Additional variables the quantizer may receive.

        Returns:
            The quantized tensor.
        """

        auxvar = weights[gptq_constants.AUXVAR]
        ar_iter = weights[gptq_constants.GPTQ_ITER]
        ptq_threshold_tensor = weights[self.PTQ_THRESHOLD]
        aux_index_shift = weights[gptq_constants.AUXSHIFT]
        # Refresh tau and (in training) the Gumbel noise for this step.
        self.update_iteration(training, ar_iter)
        if self.per_axis:
            input_shape = inputs.shape
            n_axis = len(input_shape)
            quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
                self.quantization_axis
            reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]

            # Shift vector broadcast over a leading candidate axis of size m.
            reshape_shape_aux_ind = [-1, *[1 for _ in range(n_axis)]]
            #####################################################
            # Gumbel Softmax
            #####################################################
            if training:
                gumbel_scale = self.gumbel_scale if self.gumbel_scale_per_bitwidth is None \
                    else self.gumbel_scale_per_bitwidth.get(self.num_bits, self.gumbel_scale)
                p_t = gumbel_softmax(auxvar, self.tau, self.g_t, gumbel_scale=gumbel_scale)
            else:
                # Evaluation: zero noise at the minimal temperature (near one-hot).
                p_t = gumbel_softmax(auxvar, self.minimal_temp, 0)
            # Straight-through hardening to one-hot.
            p_t = ste_gumbel(p_t)
            self.p_t = p_t
            #####################################################
            # Calculate v hat and threshold hat
            #####################################################
            ptq_threshold_tensor_hat = tf.reshape(ptq_threshold_tensor, reshape_shape)
            # Expected shift under the (hardened) candidate distribution.
            auxvar_hat = tf.reduce_sum(p_t * tf.reshape(aux_index_shift, reshape_shape_aux_ind), axis=0)
            #####################################################
            # Quantized Input
            #####################################################
            q_tensor = gumbel_rounding_symmetric_quantizer(inputs, auxvar_hat,
                                                           ptq_threshold_tensor_hat,
                                                           self.num_bits,
                                                           self.signed,
                                                           self.power_of_two)
            if self.quantization_parameter_learning and not self.power_of_two:
                # Apply the learned per-channel scale on top of the threshold.
                scale = tf.reshape(self.quantizer_parameters[self.SCALE_PTQ], reshape_shape)
                q_tensor *= scale

            return q_tensor
        else:
            # NOTE(review): the per-tensor path feeds the raw auxiliary variable
            # (shape [m, *w_shape]) directly, without the Gumbel-weighted
            # expectation used above — confirm this is intended.
            return gumbel_rounding_symmetric_quantizer(inputs, auxvar,
                                                       ptq_threshold_tensor,
                                                       self.num_bits,
                                                       self.signed,
                                                       self.power_of_two)

    def get_quant_config(self, layer) -> Dict[str, np.ndarray]:
        """
        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining

        Args:
            layer: quantized layer

        Returns:
            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
            Keys must match NodeQuantizationConfig attributes

        """

        if self.power_of_two:
            # Snap the trained threshold back to a power of two for export.
            old_threshold = self.quantizer_parameters[self.PTQ_THRESHOLD]
            old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
        else:
            old_threshold = self.quantizer_parameters[self.PTQ_THRESHOLD]
            if self.quantization_parameter_learning:
                # Fold the learned scale into the exported threshold.
                scale = tf.reshape(self.quantizer_parameters[self.SCALE_PTQ], self.threshold_shape)
                old_threshold = old_threshold * scale
        # Export as a numpy array in the caller's original threshold shape.
        old_threshold = old_threshold.numpy()
        old_threshold = old_threshold.reshape(self.threshold_shape)
        return {THRESHOLD: old_threshold}