mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py (new file, +244)
@@ -0,0 +1,244 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ import numpy as np
+
+ from model_compression_toolkit.core.common import max_power_of_two
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_config import RoundingType
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
+     BasePytorchGPTQTrainableQuantizer
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
+ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
+ from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
+     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
+ from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+     get_threshold_reshape_shape
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+
+
+ def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
+                                       auxvar_tensor: torch.Tensor,
+                                       threshold_tensor: torch.Tensor,
+                                       num_bits: int,
+                                       signed: bool,
+                                       power_of_two: bool) -> torch.Tensor:
+     """
+     Quantize a tensor symmetrically for GPTQ quantizers.
+
+     Args:
+         input_tensor: Tensor to quantize. Values of this tensor are not changed during GPTQ.
+         auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to GPTQ training.
+         threshold_tensor: Tensor with values to compute the threshold.
+         num_bits: Num of bits to use.
+         signed: Signedness of the quantization range.
+         power_of_two: Whether the threshold should be constrained to a power of two or not.
+
+     Returns:
+         A quantized tensor.
+     """
+
+     if power_of_two:
+         threshold_tensor = qutils.power_of_two_max(threshold_tensor)
+     delta = qutils.calculate_delta(threshold_tensor, num_bits, signed)
+     with torch.no_grad():
+         input_tensor_int = torch.floor(input_tensor / delta)
+     tensor_q = input_tensor_int + auxvar_tensor
+     int_threshold = 2 ** (num_bits - int(signed))
+     return delta * qutils.ste_clip(tensor_q,
+                                    min_val=-int(signed) * int_threshold,
+                                    max_val=int_threshold - 1)
+
+
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                 quantizer_type=RoundingType.SoftQuantizer)
+ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
+     """
+     Trainable symmetric quantizer that optimizes the rounding of the quantized values using a soft quantization method.
+     """
+
+     def __init__(self,
+                  quantization_config: TrainableQuantizerWeightsConfig,
+                  quantization_parameter_learning: bool = False):
+         """
+         Construct a PyTorch fake weight quantizer that applies soft rounding to a symmetric quantizer.
+
+         Args:
+             quantization_config: Trainable weights quantizer config.
+             quantization_parameter_learning (bool): Whether to learn the threshold or not.
+         """
+
+         super().__init__(quantization_config)
+         self.num_bits = quantization_config.weights_n_bits
+         self.per_channel = quantization_config.weights_per_channel_threshold
+
+         threshold_values = quantization_config.weights_quantization_params[THRESHOLD]
+         self.threshold_shape = np.asarray(threshold_values).shape
+         self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_channel else float(
+             threshold_values)
+
+         self.quantization_axis = quantization_config.weights_channels_axis
+         self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+         self.quantization_parameter_learning = quantization_parameter_learning
+
+         # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
+         # See: https://arxiv.org/pdf/2004.10568.pdf
+         self.gamma = SOFT_ROUNDING_GAMMA
+         self.zeta = SOFT_ROUNDING_ZETA
+
+         self.quantizer_parameters = {}
+
+     def initialize_quantization(self,
+                                 tensor_shape: torch.Size,
+                                 name: str,
+                                 layer: qi.PytorchQuantizationWrapper):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary.
+
+         Args:
+             tensor_shape: Tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+
+         if self.per_channel:
+             threshold_tensor = to_torch_tensor(self.threshold_values)
+         else:
+             threshold_tensor = torch.tensor(self.threshold_values)
+         layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
+                                  nn.Parameter(threshold_tensor, requires_grad=False))
+
+         w = layer.layer.weight
+         delta = qutils.calculate_delta(threshold_tensor.reshape(self.threshold_shape), self.num_bits, signed=True)
+         w_clipped_normed = torch.clip(w / delta, -2**(self.num_bits-1), 2**(self.num_bits-1)-1)
+         rest = w_clipped_normed - torch.floor(w_clipped_normed)  # rest of rounding [0, 1)
+         # Note that (rest - self.gamma) can't be zero since rest is non-negative and gamma is
+         # negative, so the division is safe
+         alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
+
+         layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
+
+         # save the quantizer added parameters for later calculations
+         self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
+
+         if self.quantization_parameter_learning:
+             if self.per_channel:
+                 layer.register_parameter(f"{name}_{SCALE_PTQ}",
+                                          nn.Parameter(to_torch_tensor(torch.ones_like(torch.Tensor(self.threshold_values))),
+                                                       requires_grad=True))
+             else:
+                 layer.register_parameter(f"{name}_{SCALE_PTQ}",
+                                          nn.Parameter(to_torch_tensor(torch.tensor([1.0], requires_grad=True))))
+             self.add_quantizer_variable(SCALE_PTQ, layer.get_parameter(f"{name}_{SCALE_PTQ}"), VariableGroup.QPARAMS)
+
+     def get_soft_targets(self) -> torch.Tensor:
+         """
+         Computes the rectified sigmoid function for the quantization target parameters.
+
+         Returns:
+             A tensor with the soft rounding target values.
+
+         """
+         scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
+         return torch.clip(scaled_sigmoid, min=0, max=1)
+
+     def get_quant_config(self) -> Dict[str, np.ndarray]:
+         """
+         Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
+
+         Returns:
+             A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+             Keys must match NodeQuantizationConfig attributes.
+
+         """
+         old_threshold = torch_tensor_to_numpy(self.get_quantizer_variable(PTQ_THRESHOLD))
+         old_threshold = np.resize(old_threshold, self.threshold_shape)
+         if self.power_of_two:
+             old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
+         else:
+             if self.quantization_parameter_learning:
+                 scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
+                 old_threshold = old_threshold * torch_tensor_to_numpy(scale)
+             old_threshold = old_threshold.reshape(self.threshold_shape)
+         return {THRESHOLD: old_threshold}
+
+     def __call__(self,
+                  inputs: nn.Parameter,
+                  training: bool) -> torch.Tensor:
+         """
+         Quantize a tensor.
+
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether in training mode or not.
+
+         Returns:
+             The quantized tensor.
+         """
+         auxvar = self.get_quantizer_variable(AUXVAR)
+         ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
+
+         #####################################################
+         # Soft Rounding
+         #####################################################
+         aux_var = self.get_soft_targets()
+         if not training:
+             aux_var = (aux_var >= 0.5).to(auxvar.dtype)
+
+         if self.per_channel:
+             reshape_shape = get_threshold_reshape_shape(inputs.shape,
+                                                         quant_axis=self.quantization_axis,
+                                                         quant_axis_dim=-1)
+
+             ##########################################################
+             # Calculate soft rounding targets and optimized threshold
+             ##########################################################
+             ptq_threshold_tensor_hat = torch.reshape(ptq_threshold_tensor, reshape_shape)
+
+             #####################################################
+             # Quantized Input
+             #####################################################
+             q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
+                                                          auxvar_tensor=aux_var,
+                                                          threshold_tensor=ptq_threshold_tensor_hat,
+                                                          num_bits=self.num_bits,
+                                                          signed=True,
+                                                          power_of_two=self.power_of_two)
+
+             if self.quantization_parameter_learning and not self.power_of_two:
+                 scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
+                 q_tensor *= scale
+
+         else:
+             q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
+                                                          auxvar_tensor=aux_var,
+                                                          threshold_tensor=ptq_threshold_tensor,
+                                                          num_bits=self.num_bits,
+                                                          signed=True,
+                                                          power_of_two=self.power_of_two)
+
+             if self.quantization_parameter_learning and not self.power_of_two:
+                 scale = self.get_quantizer_variable(SCALE_PTQ)
+                 q_tensor *= scale
+
+         return q_tensor
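The soft-rounding quantizers added in this release replace hard round-to-nearest with the learned rectified sigmoid from the AdaRound paper cited in the code (https://arxiv.org/pdf/2004.10568.pdf): h(alpha) = clip(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1). A minimal standalone sketch of that function and of the alpha initialization used in initialize_quantization above follows; the gamma = -0.1, zeta = 1.1 values are the paper's suggested stretch constants and are an assumption here, since the actual SOFT_ROUNDING_GAMMA/SOFT_ROUNDING_ZETA constants live in gptq_constants.py, which this diff does not display.

    import torch

    # Assumed stretch constants (AdaRound's suggested values); the real ones are
    # defined in model_compression_toolkit/gptq/common/gptq_constants.py.
    GAMMA, ZETA = -0.1, 1.1

    def rectified_sigmoid(alpha: torch.Tensor) -> torch.Tensor:
        # h(alpha) in [0, 1]: a soft, trainable stand-in for the rounding residual.
        return torch.clip(torch.sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, min=0, max=1)

    # Initialization inverts h so that h(alpha) starts exactly at the float
    # rounding residual, mirroring the alpha computation in the quantizers above.
    w_normed = torch.randn(8) * 3            # stand-in for w / delta
    rest = w_normed - torch.floor(w_normed)  # residual in [0, 1)
    alpha = -torch.log((ZETA - GAMMA) / (rest - GAMMA) - 1)
    assert torch.allclose(rectified_sigmoid(alpha), rest, atol=1e-5)

During training h(alpha) moves continuously between 0 and 1 (pushed toward the endpoints by the regularizer in soft_quantizer_reg.py); at inference the quantizers above harden it with (aux_var >= 0.5).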
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py (new file, +196)
@@ -0,0 +1,196 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ import numpy as np
+
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_config import RoundingType
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
+     BasePytorchGPTQTrainableQuantizer
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
+ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
+ from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
+ from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
+     mark_quantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
+     VariableGroup
+ from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
+ from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
+
+ def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
+                                     auxvar_tensor: torch.Tensor,
+                                     min_range: torch.Tensor,
+                                     max_range: torch.Tensor,
+                                     num_bits: int) -> torch.Tensor:
+     """
+     Quantize a tensor uniformly for GPTQ quantizers.
+
+     Args:
+         input_tensor: Tensor to quantize. Values of this tensor are not changed during GPTQ.
+         auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to GPTQ training.
+         min_range: Tensor with min values to compute the delta grid.
+         max_range: Tensor with max values to compute the delta grid.
+         num_bits: Num of bits to use.
+
+     Returns:
+         A quantized tensor.
+     """
+     # Adjust the quantization range so the quantization grid includes zero.
+     min_range, max_range = fix_range_to_include_zero(min_range, max_range, num_bits)
+     delta = qutils.calculate_delta_uniform(max_range, min_range, num_bits)
+     with torch.no_grad():
+         input_tensor_int = torch.floor(input_tensor / delta)
+     tensor_q = input_tensor_int + auxvar_tensor
+     return delta * qutils.ste_clip(tensor_q,
+                                    min_val=0,
+                                    max_val=2 ** num_bits - 1)
+
+
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                 quantization_method=[QuantizationMethod.UNIFORM],
+                 quantizer_type=RoundingType.SoftQuantizer)
+ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
+     """
+     Trainable uniform quantizer that optimizes the rounding of the quantized values using a soft quantization method.
+     """
+
+     def __init__(self,
+                  quantization_config: TrainableQuantizerWeightsConfig,
+                  quantization_parameter_learning: bool = False):
+         """
+         Construct a PyTorch fake weight quantizer that applies soft rounding to a uniform quantizer.
+
+         Args:
+             quantization_config: Trainable weights quantizer config.
+             quantization_parameter_learning (bool): Whether to learn the min/max ranges or not.
+         """
+
+         super().__init__(quantization_config)
+         self.num_bits = quantization_config.weights_n_bits
+         self.per_channel = quantization_config.weights_per_channel_threshold
+
+         self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
+         self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
+
+         self.quantization_axis = quantization_config.weights_channels_axis
+         self.quantization_parameter_learning = quantization_parameter_learning
+
+         # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
+         # See: https://arxiv.org/pdf/2004.10568.pdf
+         self.gamma = SOFT_ROUNDING_GAMMA
+         self.zeta = SOFT_ROUNDING_ZETA
+
+     def initialize_quantization(self,
+                                 tensor_shape: torch.Size,
+                                 name: str,
+                                 layer: qi.PytorchQuantizationWrapper):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary.
+
+         Args:
+             tensor_shape: Tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+
+         # Add min and max variables to layer.
+         if self.per_channel:
+             min_values = to_torch_tensor(self.min_values)
+             max_values = to_torch_tensor(self.max_values)
+         else:
+             min_values = torch.tensor(self.min_values)
+             max_values = torch.tensor(self.max_values)
+
+         layer.register_parameter(name + "_" + FQ_MIN, nn.Parameter(min_values, requires_grad=self.quantization_parameter_learning))
+         layer.register_parameter(name + "_" + FQ_MAX, nn.Parameter(max_values, requires_grad=self.quantization_parameter_learning))
+
+         w = layer.layer.weight
+         delta = qutils.calculate_delta_uniform(max_values, min_values, self.num_bits)
+         w_clipped_normed = torch.clip(w / delta, 0, 2 ** self.num_bits - 1)
+         rest = w_clipped_normed - torch.floor(w_clipped_normed)  # rest of rounding [0, 1)
+         alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
+         layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
+
+         # Save the quantizer parameters
+         self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name + "_" + FQ_MIN), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name + "_" + FQ_MAX), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
+
+     def get_soft_targets(self) -> torch.Tensor:
+         """
+         Computes the rectified sigmoid function for the quantization target parameters.
+
+         Returns:
+             A tensor with the soft rounding target values.
+
+         """
+         scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
+         return torch.clip(scaled_sigmoid, min=0, max=1)
+
+     def get_quant_config(self) -> Dict[str, np.ndarray]:
+         """
+         Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
+
+         Returns:
+             A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+             Keys must match NodeQuantizationConfig attributes.
+
+         """
+         min_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MIN))
+         max_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MAX))
+         return {RANGE_MIN: min_values,
+                 RANGE_MAX: max_values}
+
+     def __call__(self,
+                  inputs: nn.Parameter,
+                  training: bool) -> torch.Tensor:
+         """
+         Quantize a tensor.
+
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether in training mode or not.
+
+         Returns:
+             The quantized tensor.
+         """
+         auxvar = self.get_quantizer_variable(AUXVAR)
+         min_range = self.get_quantizer_variable(FQ_MIN)
+         max_range = self.get_quantizer_variable(FQ_MAX)
+
+         #####################################################
+         # Soft Rounding
+         #####################################################
+         aux_var = self.get_soft_targets()
+         if not training:
+             aux_var = (aux_var >= 0.5).to(auxvar.dtype)
+
+         #####################################################
+         # Quantized Input
+         #####################################################
+         q_tensor = soft_rounding_unifrom_quantizer(input_tensor=inputs,
+                                                    auxvar_tensor=aux_var,
+                                                    min_range=min_range,
+                                                    max_range=max_range,
+                                                    num_bits=self.num_bits)
+
+         return q_tensor
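Unlike the symmetric variant, UniformSoftRoundingGPTQ keeps FQ_MIN/FQ_MAX range parameters instead of a threshold, and fix_range_to_include_zero first nudges the range so that the uniform grid contains an exact zero (so zero-valued weights and padding quantize losslessly). The implementation of fix_range_to_include_zero sits in gptq/pytorch/quantizer/quant_utils.py and is not part of this diff; the sketch below shows one common way such an adjustment is formulated, and is illustrative only.

    import torch

    def include_zero_in_grid(range_min: torch.Tensor,
                             range_max: torch.Tensor,
                             n_bits: int):
        # Illustrative only; the real fix_range_to_include_zero may differ.
        # Assumes a two-sided range (min <= 0 <= max); one-sided ranges are
        # usually clamped to zero first.
        scale = (range_max - range_min) / (2 ** n_bits - 1)
        # Find the grid index nearest zero, then shift both endpoints so that
        # grid point becomes exactly zero. The span (and hence scale) is kept.
        zero_idx = torch.round(-range_min / scale)
        new_min = -zero_idx * scale
        new_max = new_min + (2 ** n_bits - 1) * scale
        return new_min, new_max

    # Example: an 8-bit range whose grid initially misses zero.
    mn, mx = include_zero_in_grid(torch.tensor([-0.37]), torch.tensor([1.53]), n_bits=8)
    grid = mn + torch.arange(2 ** 8) * (mx - mn) / (2 ** 8 - 1)
    assert grid.abs().min() < 1e-6  # zero now lies on the grid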
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py (new file, +182)
@@ -0,0 +1,182 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ import numpy as np
+ from model_compression_toolkit.core.common.defaultdict import DefaultDict
+
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_config import RoundingType
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
+     BasePytorchGPTQTrainableQuantizer
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
+ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
+ from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
+ from model_compression_toolkit.core.common.constants import THRESHOLD
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
+     mark_quantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+     get_threshold_reshape_shape
+
+
+ def pertubation_symmetric_quantizer(input_tensor: torch.Tensor,
+                                     auxvar_tensor: nn.Parameter,
+                                     max_tensor: torch.Tensor,
+                                     num_bits: int,
+                                     signed: bool,
+                                     power_of_two: bool,
+                                     max_lsbs_change: int = MAX_LSB_CHANGE) -> nn.Parameter:
+     """
+     Quantize a tensor symmetrically with a maximum LSBs shift.
+
+     Args:
+         input_tensor: Tensor to quantize. Values of this tensor are not changed during GPTQ.
+         auxvar_tensor: Tensor that manifests the bit shift of the weight due to GPTQ.
+         max_tensor: Tensor with max values to compute the threshold.
+         num_bits: Num of bits to use.
+         signed: Signedness of the quantization range.
+         power_of_two: Whether the threshold should be constrained to a power of two or not.
+         max_lsbs_change: Maximum number of LSBs that the auxvar is allowed to change.
+
+     Returns:
+         A quantized tensor.
+     """
+
+     if power_of_two:
+         max_tensor = qutils.power_of_two_max(max_tensor)
+     delta = qutils.calculate_delta(max_tensor, num_bits, signed)
+     delta = to_torch_tensor(delta)
+     max_tensor_change = delta * max_lsbs_change
+
+     min_int = -int(signed) * (2 ** (num_bits - int(signed)))
+     max_int = (2 ** (num_bits - int(signed))) - 1
+
+     tensor_clipped = qutils.ste_clip(auxvar_tensor, min_val=-max_tensor_change, max_val=max_tensor_change) / delta
+     input_tensor_int = torch.round(input_tensor / delta).detach()
+
+     tensor_q = qutils.ste_round(qutils.ste_round(input_tensor_int + tensor_clipped))
+
+     return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)
+
+
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                 quantizer_type=RoundingType.STE)
+ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
+     """
+     Trainable symmetric quantizer to quantize a layer's weights.
+     """
+
+     def __init__(self,
+                  quantization_config: TrainableQuantizerWeightsConfig,
+                  max_lsbs_change_map: dict = DefaultDict({}, lambda: 1)):
+         """
+         Construct a PyTorch fake weight quantizer that applies STE (Straight-Through Estimator) rounding to a symmetric quantizer.
+
+         Args:
+             quantization_config: Trainable weights quantizer config.
+             max_lsbs_change_map: Mapping from a bit width to the maximum number of LSBs the quantized value may shift.
+         """
+         super().__init__(quantization_config)
+         self.num_bits = quantization_config.weights_n_bits
+         self.per_channel = quantization_config.weights_per_channel_threshold
+
+         threshold_values = quantization_config.weights_quantization_params[THRESHOLD]
+         self.threshold_shape = np.asarray(threshold_values).shape
+         self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_channel else float(
+             threshold_values)
+
+         self.quantization_axis = quantization_config.weights_channels_axis
+         self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+         self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
+
+     def initialize_quantization(self,
+                                 tensor_shape: torch.Size,
+                                 name: str,
+                                 layer: qi.PytorchQuantizationWrapper):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary.
+
+         Args:
+             tensor_shape: Tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+
+         layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
+                                  nn.Parameter(torch.tensor(self.threshold_values, requires_grad=False)
+                                               if not self.per_channel
+                                               else to_torch_tensor(self.threshold_values), requires_grad=False))
+         layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(to_torch_tensor(torch.zeros(self.threshold_shape)),
+                                                                    requires_grad=True))
+
+         # save the quantizer added parameters for later calculations
+         self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
+         self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
+
+     def get_quant_config(self) -> Dict[str, np.ndarray]:
+         """
+         Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
+
+         Returns:
+             A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+             Keys must match NodeQuantizationConfig attributes.
+
+         """
+         old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
+         return {THRESHOLD: torch_tensor_to_numpy(old_threshold).reshape(self.threshold_shape)}
+
+     def __call__(self,
+                  inputs: nn.Parameter,
+                  training: bool) -> nn.Parameter:
+         """
+         Quantize a tensor.
+
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether in training mode or not.
+
+         Returns:
+             The quantized tensor.
+         """
+         auxvar = self.get_quantizer_variable(AUXVAR)
+         ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
+
+         if self.per_channel:
+             reshape_shape = get_threshold_reshape_shape(inputs.shape,
+                                                         quant_axis=self.quantization_axis,
+                                                         quant_axis_dim=-1)
+             ptq_threshold_tensor = torch.reshape(ptq_threshold_tensor, reshape_shape)
+
+             q_tensor = pertubation_symmetric_quantizer(inputs,
+                                                        auxvar,
+                                                        ptq_threshold_tensor,
+                                                        self.num_bits,
+                                                        signed=True,
+                                                        power_of_two=self.power_of_two,
+                                                        max_lsbs_change=self.max_lsbs_change)
+             return q_tensor
+         else:
+             return pertubation_symmetric_quantizer(inputs,
+                                                    auxvar,
+                                                    ptq_threshold_tensor,
+                                                    self.num_bits,
+                                                    signed=True,
+                                                    power_of_two=self.power_of_two)
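All three PyTorch GPTQ quantizers above route rounding and clipping through qutils.ste_round and qutils.ste_clip so the hard operations stay differentiable with respect to the trainable variables. Those helpers live in gptq/pytorch/quantizer/quant_utils.py, which this diff does not show; the following is a minimal sketch of the standard straight-through-estimator formulation they presumably follow, not necessarily MCT's exact code.

    import torch

    def ste_round(x: torch.Tensor) -> torch.Tensor:
        # Forward: round(x). Backward: identity gradient, because the detached
        # correction term contributes nothing to autograd.
        return x + (torch.round(x) - x).detach()

    def ste_clip(x: torch.Tensor, min_val, max_val) -> torch.Tensor:
        # Forward: clip(x). Backward: identity gradient (gradients are not
        # zeroed outside the clipping range).
        return x + (torch.clip(x, min_val, max_val) - x).detach()

With this trick, pertubation_symmetric_quantizer emits hard integer grid values in the forward pass while the AUXVAR parameter still receives useful gradients during GPTQ fine-tuning.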
model_compression_toolkit/core/keras/quantization_facade.py
@@ -125,8 +125,8 @@ if FOUND_TF:
          if core_config.mixed_precision_enable:
              if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
                  common.Logger.error("Given quantization config to mixed-precision facade is not of type "
-                                     "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization API,"
-                                     "or pass a valid mixed precision configuration.")
+                                     "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                                     "API, or pass a valid mixed precision configuration.")  # pragma: no cover

          common.Logger.info("Using experimental mixed-precision quantization. "
                             "If you encounter an issue please file a bug.")
@@ -171,4 +171,4 @@ else:
      def keras_post_training_quantization_experimental(*args, **kwargs):
          Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                          'when using keras_post_training_quantization_experimental. '
-                         'Could not find Tensorflow package.')
+                         'Could not find Tensorflow package.')  # pragma: no cover