mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
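Two changes dominate this release. First, the misspelled qunatizers_infrastructure package (items 229-238) is replaced by quantizers_infrastructure, split into inferable_infrastructure (quantizers with fixed parameters, apparently for inference and export) and trainable_infrastructure (quantizers with learnable parameters, used by GPTQ and QAT), as the new file paths suggest. Second, the GPTQ and QAT stacks are rebuilt on top of that package, dropping the Gumbel-rounding quantizers in favor of STE and soft-rounding quantizers. The hunks below show this migration for the Keras GPTQ files (items 89-95 in the list). A minimal sketch of the new namespace, using only names that appear in the file list and hunks below:

    # Sketch only; names taken from the file list and hunks below.
    import model_compression_toolkit as mct
    from model_compression_toolkit import quantizers_infrastructure as qi

    # GPTQ entry points now live under the mct.gptq namespace (see the facade hunks),
    # and wrapped layers use qi.KerasQuantizationWrapper instead of TFMOT's QuantizeWrapper.
    gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=5)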
model_compression_toolkit/gptq/keras/gptq_training.py

@@ -12,16 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union

 import tensorflow as tf
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
+from keras import Model
+from tensorflow.keras.layers import Layer
 from tqdm import tqdm

 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
-from model_compression_toolkit.gptq.keras.gptq_model_builder import GPTQKerasModelBuilder
+from model_compression_toolkit.core.common.user_info import UserInformation
+from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
 from packaging import version

+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
+from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+
 if version.parse(tf.__version__) < version.parse("2.6"):
     from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
 else:
@@ -31,15 +37,14 @@ from model_compression_toolkit.core import common
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.gptq.keras.graph_info import get_trainable_parameters, get_weights_for_loss, \
-    get_gumbel_probability
+from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
+from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 import numpy as np
 import copy
 from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
-from model_compression_toolkit.gptq.keras.quantizer import WeightQuantizeConfig
-from model_compression_toolkit.gptq.keras.optimizers.sam_optimizer import SAM
+from model_compression_toolkit import quantizers_infrastructure as qi


 class KerasGPTQTrainer(GPTQTrainer):
@@ -77,11 +82,10 @@ class KerasGPTQTrainer(GPTQTrainer):
         self.loss_list = []
         self.input_scale = 1

-        trainable_weights, bias_weights, trainable_threshold, temperature_weights = get_trainable_parameters(
+        trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
             self.fxp_model,
             fw_info,
-            add_bias=gptq_config.train_bias,
-            is_gumbel=gptq_config.is_gumbel)
+            add_bias=gptq_config.train_bias)

         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
@@ -96,29 +100,70 @@
         trainable_quantization_parameters = trainable_threshold
         self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
                                                                   flattened_bias_weights,
-                                                                  trainable_quantization_parameters,
-                                                                  temperature_weights)
-        self.has_params_to_train = np.sum([len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param])>0
+                                                                  trainable_quantization_parameters)
+        self.has_params_to_train = np.sum(
+            [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0

         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
             common.Logger.error("Input scale mismatch between float and GPTQ networks")  # pragma: no cover
         else:
             self.input_scale = self.gptq_user_info.input_scale

-        self.weights_for_average_loss = self.compute_jacobian_based_weights(representative_data_gen)
+        self.weights_for_average_loss = self.compute_hessian_based_weights(representative_data_gen)
+
+        self.reg_func = get_regularization(self.gptq_config, representative_data_gen)

-    def build_gptq_model(self):
+    def _is_gptq_applicable(self,
+                            node: common.BaseNode) -> bool:
+        """
+        A function for deciding if a layer should be fine-tuned during GPTQ.
+
+        Args:
+            node (BaseNode): Node for quantization decision.
+
+        Returns:
+            A boolean stating whether the layer is to be wrapped with a QuantizeWrapper.
+        """
+
+        if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
+            common.Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
+                                f"without a kernel isn't supported")
+        return node.is_weights_quantization_enabled()
+
+    def gptq_wrapper(self, n: common.BaseNode, layer: Layer) -> Union[qi.KerasQuantizationWrapper, Layer]:
+        """
+        A function that takes a computational graph node and a Keras layer and performs the quantization wrapping.
+
+        Args:
+            n: A node of the MCT graph.
+            layer: A Keras layer.
+
+        Returns: The wrapped layer if the layer should be wrapped; otherwise, the layer as is.
+
+        """
+        if self._is_gptq_applicable(n):
+            weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config)
+            return qi.KerasQuantizationWrapper(layer,
+                                               weights_quantizers=weights_quantizers,
+                                               activation_quantizers=activation_quantizers)
+        else:
+            return layer
+
+    def build_gptq_model(self) -> Tuple[Model, UserInformation]:
         """
         Build the GPTQ model with QuantizationWrappers
+
         Returns:
             Quantized graph for GPTQ fine-tuning, GPTQ graph user info
         """

-        return GPTQKerasModelBuilder(graph=self.graph_quant,
-                                     gptq_config=self.gptq_config,
-                                     append2output=self.compare_points,
-                                     fw_info=self.fw_info,
-                                     return_float_outputs=True).build_model()
+        gptq_model, gptq_user_info = KerasModelBuilder(graph=self.graph_quant,
+                                                       append2output=self.compare_points,
+                                                       fw_info=self.fw_info,
+                                                       return_float_outputs=True,
+                                                       wrapper=self.gptq_wrapper).build_model()
+
+        return gptq_model, gptq_user_info

     def compute_gradients(self, in_y_float: List[tf.Tensor], input_data: List[np.ndarray],
                           in_optimizer_with_param: List,
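The design change above is worth noting: the dedicated GPTQKerasModelBuilder is gone, and the generic KerasModelBuilder now accepts a per-node wrapper callable, so GPTQ only supplies the wrapping policy. A sketch of the pattern, assuming graph, fw_info, and gptq_config objects are already in hand (illustrative, not a complete training setup):

    # Sketch: plugging a custom wrapper callable into KerasModelBuilder.
    def my_wrapper(node, layer):
        # Wrap only layers whose weights are quantized; pass the rest through untouched.
        if node.is_weights_quantization_enabled():
            w_qs, a_qs = quantization_builder(node, gptq_config)
            return qi.KerasQuantizationWrapper(layer,
                                               weights_quantizers=w_qs,
                                               activation_quantizers=a_qs)
        return layer

    model, user_info = KerasModelBuilder(graph=graph,
                                         fw_info=fw_info,
                                         return_float_outputs=True,
                                         wrapper=my_wrapper).build_model()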
@@ -149,18 +194,9 @@ class KerasGPTQTrainer(GPTQTrainer):
                                             self.compare_points_std,
                                             self.weights_for_average_loss)

-            if self.gptq_config.is_gumbel and self.gptq_config.quantizer_config.temperature_learning:
-                gumbel_prob = get_gumbel_probability(self.fxp_model)
-                gumbel_reg = 0
-                for p in gumbel_prob:
-                    entropy = -tf.reduce_mean(
-                        tf.reduce_sum(p * tf.math.log(tf.maximum(p,
-                                                                 self.gptq_config.eps)),
-                                      axis=0))
+            reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)

-                    gumbel_reg += entropy
-                gumbel_reg /= len(gumbel_prob)
-                loss_value += self.gptq_config.quantizer_config.gumbel_entropy_regularization * gumbel_reg
+            loss_value += reg_value

         # Use the gradient tape to automatically retrieve
         # the gradients of the trainable variables with respect to the loss.
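The inlined Gumbel entropy penalty above is replaced by a regularization callable obtained once from get_regularization(gptq_config, representative_data_gen) and invoked per step as reg_func(model, regularization_factor), which keeps compute_gradients agnostic of the rounding scheme. A trivial sketch of the assumed contract (assumption: a rounding scheme with nothing to penalize, such as plain STE, can simply contribute zero; the actual soft-rounding regularizer is added in soft_quantizer_reg.py, item 98 above):

    # Sketch of the regularization contract assumed by compute_gradients above.
    def zero_regularization(fxp_model, regularization_factor):
        # Nothing to penalize: contribute 0.0 to the loss.
        return 0.0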
@@ -179,9 +215,6 @@
             representative_data_gen: Dataset to use for inputs of the models.
         """
         compute_gradients = self.compute_gradients
-        if self.gptq_config.sam_optimization:
-            sam = SAM(self.fxp_model, self.compute_gradients, self.optimizer_with_param, self.gptq_config.rho)
-            compute_gradients = sam.compute_gradients

         # ----------------------------------------------
         # Training loop
@@ -237,7 +270,8 @@
         for data in tqdm(data_function()):
             input_data = [d * self.input_scale for d in data]

-            loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients, in_optimizer_with_param, is_training)
+            loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
+                                                             in_optimizer_with_param, is_training)
             # Run one step of gradient descent by updating
             # the value of the variables to minimize the loss.
             for i, (o, p) in enumerate(in_optimizer_with_param):
@@ -258,16 +292,17 @@
         graph = copy.copy(self.graph_quant)

         for layer in self.fxp_model.layers:
-            if isinstance(layer, QuantizeWrapper) and isinstance(
-                    layer.quantize_config, WeightQuantizeConfig):
+            if isinstance(layer, KerasQuantizationWrapper):
                 node = graph.find_node_by_name(layer.layer.name)
                 if len(node) == 0 and isinstance(layer.layer, TensorFlowOpLayer):
                     node = graph.find_node_by_name('_'.join(layer.layer.name.split('_')[3:]))
                 if len(node) != 1:
                     common.Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
                 node = node[0]
+                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                                      fw_info=self.fw_info)
                 weights, weight_quant_config, activation_quant_config = \
-                    layer.quantize_config.update_layer_quantization_params(layer)
+                    layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
                 for weight_attr, weight in weights.items():
                     node.set_weights_by_keys(weight_attr, weight.numpy())
                 for config_attr, config_value in weight_quant_config.items():
@@ -281,4 +316,3 @@
                 node.set_weights_by_keys(BIAS, new_bias)

         return graph
-
model_compression_toolkit/gptq/keras/graph_info.py

@@ -13,22 +13,21 @@
 # limitations under the License.
 # ==============================================================================

-
 import tensorflow as tf
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
 from typing import Tuple, List
-
 from model_compression_toolkit.core.keras.constants import USE_BIAS
-from model_compression_toolkit.gptq.keras.quantizer import WeightQuantizeConfig
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from tensorflow.keras.models import Model
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


-def get_trainable_parameters(fxp_model: Model,
-                             fw_info: FrameworkInfo,
-                             add_bias: bool = False,
-                             is_gumbel: bool = False) -> (
-        List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable]):
+def get_gptq_trainable_parameters(fxp_model: Model,
+                                  fw_info: FrameworkInfo,
+                                  add_bias: bool = False) -> (
+        List[tf.Variable], List[tf.Variable], List[tf.Variable]):
     """
     Get trainable parameters from all layers in a model

@@ -36,7 +35,6 @@ def get_trainable_parameters(fxp_model: Model,
         fxp_model: Model to get its trainable parameters.
         fw_info: Framework information needed for keras kernel ops list.
         add_bias: Whether to include biases of the model (if there are) or not.
-        is_gumbel: Whether the fxp model is quantized using Gumbel Rounding

     Returns:
         A list of trainable variables in a model. Each item is a list of a layers weights.
@@ -45,15 +43,17 @@
     trainable_weights: List[tf.Tensor] = []
     trainable_threshold: List[tf.Tensor] = []
     bias_weights: List[List[tf.Tensor]] = []
-    temperature_weights: List[tf.Tensor] = []
+
     for layer in fxp_model.layers:
-        if isinstance(layer, QuantizeWrapper) and isinstance(
-                layer.quantize_config, WeightQuantizeConfig):
-            # collect trainable weights per layer
-            layer_trainable_weights = layer.quantize_config.get_aux_variable()
-            layer_trainable_threshold = layer.quantize_config.get_quantization_variable()
-            if is_gumbel:
-                temperature_weights.append(layer.quantize_config.get_temperature_variable())
+        if isinstance(layer, KerasQuantizationWrapper):
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)
+
+            # collect trainable weights per quantizer
+            quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
+            quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
+            trainable_weights.append(quantizer_trainable_weights)
+            trainable_threshold.extend(quantizer_trainable_threshold)

             if add_bias:
                 kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
@@ -61,27 +61,8 @@
                            and layer.layer.get_config().get(USE_BIAS)
                 if use_bias is not None and use_bias:
                     bias_weights.append([layer.layer.bias])
-            trainable_weights.append(layer_trainable_weights)
-            trainable_threshold.extend(layer_trainable_threshold)

-    return trainable_weights, bias_weights, trainable_threshold, temperature_weights
-
-
-def get_gumbel_probability(fxp_model: Model) -> List[tf.Tensor]:
-    """
-    This function return the gumbel softmax probability of GumRounding
-    Args:
-        fxp_model: A model to be quantized with GumRounding
-
-    Returns: A list of tensors.
-
-    """
-    gumbel_prob_aux: List[tf.Tensor] = []
-    for layer in fxp_model.layers:
-        if isinstance(layer, QuantizeWrapper) and isinstance(
-                layer.quantize_config, WeightQuantizeConfig):
-            gumbel_prob_aux.append(layer.quantize_config.get_gumbel_probability())
-    return gumbel_prob_aux
+    return trainable_weights, bias_weights, trainable_threshold


 def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
@@ -99,14 +80,14 @@ def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
     flp_weights_list = []
     fxp_weights_list = []
     for layer in fxp_model.layers:
-        if isinstance(layer, QuantizeWrapper) and isinstance(
-                layer.quantize_config, WeightQuantizeConfig):
+        if isinstance(layer, KerasQuantizationWrapper):

             # collect pairs of float and quantized weights per layer
             _layer_flp_weights, _layer_fxp_weights = [], []
-            for weight, quantizer, quantizer_vars in layer._weight_vars:
-                _layer_flp_weights.append(weight)
-                _layer_fxp_weights.append(quantizer(weight, training=False, weights=quantizer_vars))
+            for weight, quantizer_vars, quantizer in layer.get_weights_vars():
+                _layer_flp_weights.append(quantizer_vars)
+                _layer_fxp_weights.append(quantizer(training=False, inputs=quantizer_vars))
+
             flp_weights_list.append(_layer_flp_weights)
             fxp_weights_list.append(_layer_fxp_weights)
model_compression_toolkit/gptq/keras/quantization_facade.py

@@ -85,26 +85,18 @@ if common.constants.FOUND_TF:

        Create a GradientPTQConfigV2 to run for 5 epochs:

-        >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=5)
+        >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=5)

        Other Tensorflow optimizers can be passed:

-        >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())
+        >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())

        The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.

        """
        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
-        optimizer_quantization_parameter = tf.keras.optimizers.SGD(learning_rate=LR_QUANTIZATION_PARAM_DEFAULT, momentum=GPTQ_MOMENTUM)
-        return GradientPTQConfigV2(n_epochs,
-                                   optimizer,
-                                   optimizer_rest=optimizer_rest,
-                                   loss=loss,
-                                   log_function=log_function,
-                                   train_bias=True,
-                                   quantization_parameters_learning=True,
-                                   optimizer_bias=bias_optimizer,
-                                   optimizer_quantization_parameter=optimizer_quantization_parameter)
+        return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)


    def keras_gradient_post_training_quantization_experimental(in_model: Model,
@@ -183,11 +175,11 @@

        Create GPTQ config:

-        >>> gptq_config = mct.get_keras_gptq_config(n_epochs=1)
+        >>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)

        Pass the model with the representative dataset generator to get a quantized model:

-        >>> quantized_model, quantization_info = mct.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
+        >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)

        """
        KerasModelValidation(model=in_model,
@@ -196,8 +188,8 @@
        if core_config.mixed_precision_enable:
            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
                common.Logger.error("Given quantization config to mixed-precision facade is not of type "
-                                    "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization API,"
-                                    "or pass a valid mixed precision configuration.")
+                                    "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                                    "API, or pass a valid mixed precision configuration.")  # pragma: no cover

            common.Logger.info("Using experimental mixed-precision quantization. "
                               "If you encounter an issue please file a bug.")
@@ -243,10 +235,10 @@ else:
    def get_keras_gptq_config(*args, **kwargs):
        Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                        'when using keras_post_training_quantization_mixed_precision. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover


    def keras_gradient_post_training_quantization_experimental(*args, **kwargs):
        Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                        'when using keras_gradient_post_training_quantization_experimental. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover
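Taken together with the docstring updates above, the 1.8.0 Keras GPTQ flow is invoked entirely through the mct.gptq namespace. A condensed end-to-end example, assuming float_model and a representative data generator repr_datagen are defined as in the docstrings:

    import model_compression_toolkit as mct

    gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)
    quantized_model, quantization_info = \
        mct.gptq.keras_gradient_post_training_quantization_experimental(float_model,
                                                                        repr_datagen,
                                                                        gptq_config)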
model_compression_toolkit/gptq/keras/quantizer/__init__.py

@@ -13,4 +13,5 @@
 # limitations under the License.
 # ==============================================================================

-from model_compression_toolkit.gptq.keras.quantizer.configs.weight_quantizer_gptq_config import WeightQuantizeConfig
+import model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste
+import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.symmetric_soft_quantizer
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py (new file)

@@ -0,0 +1,112 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from abc import abstractmethod
+from typing import Union, Dict, List
+
+from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.core.common.constants import FOUND_TF
+from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
+
+from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+    TrainableQuantizerActivationConfig
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
+
+if FOUND_TF:
+    import tensorflow as tf
+
+    from model_compression_toolkit.quantizers_infrastructure import BaseKerasTrainableQuantizer, \
+        KerasQuantizationWrapper
+
+    class BaseKerasGPTQTrainableQuantizer(BaseKerasTrainableQuantizer):
+        """
+        A base class for trainable Keras quantizers for GPTQ.
+        """
+
+        def __init__(self,
+                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
+            """
+            Initializes a BaseKerasGPTQTrainableQuantizer object.
+
+            Args:
+                quantization_config: quantizer config class containing all the information about a quantizer configuration.
+            """
+
+            super().__init__(quantization_config)
+
+        def update_layer_quantization_params(self, layer: KerasQuantizationWrapper
+                                             ) -> (Dict[str, tf.Tensor], Dict[str, Dict], Dict):
+            """
+            A function to calculate the needed changes in NodeQuantizationConfig attributes after retraining.
+
+            Args:
+                layer: A wrapped Keras layer.
+
+            Returns:
+                3 dictionaries describing the changes to the layer's weights, weights config, and activation config
+                that occurred during GPTQ retraining.
+                Keys must match NodeQuantizationConfig attributes.
+
+            """
+            weights = {}
+            for weight, quantizer_vars, quantizer in layer.get_weights_vars():
+                if not isinstance(quantizer, BaseTrainableQuantizer):
+                    Logger.error(f"Expecting a GPTQ trainable quantizer, "  # pragma: no cover
+                                 f"but got {type(quantizer)}, which is not a trainable quantizer.")
+                weights.update({weight: quantizer(training=False, inputs=quantizer_vars)})
+
+            quant_config = {WEIGHTS_QUANTIZATION_PARAMS: self.get_quant_config()}
+
+            return weights, quant_config, {}
+
+        def get_aux_variable(self) -> List[tf.Tensor]:
+            """
+            This function returns a list with the quantizer's auxiliary quantization variables.
+
+            Returns: A list with the auxiliary quantization variables.
+
+            """
+
+            return []  # pragma: no cover
+
+        def get_quantization_variable(self) -> List[tf.Tensor]:
+            """
+            This function returns a list with the quantizer's quantization parameter variables.
+
+            Returns: A list with the quantization parameters.
+
+            """
+
+            return []  # pragma: no cover
+
+        @abstractmethod
+        def get_quant_config(self):
+            """
+            Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
+
+            Returns:
+                A dictionary of attributes that the quantizer changed during GPTQ retraining.
+                Keys must match NodeQuantizationConfig attributes.
+
+            """
+            raise NotImplementedError(f'{self.__class__.__name__} has to implement the '  # pragma: no cover
+                                      f'quantizer\'s get_quant_config.')
+
+else:
+    class BaseKerasGPTQTrainableQuantizer:  # pragma: no cover
+        def __init__(self, *args, **kwargs):
+            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                            'when using BaseKerasGPTQTrainableQuantizer. '
+                            'Could not find Tensorflow package.')
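Concrete GPTQ quantizers (the STE and soft-rounding quantizers added elsewhere in this release) subclass BaseKerasGPTQTrainableQuantizer and must implement get_quant_config. A hypothetical minimal skeleton, just to show the required surface (real quantizers also implement the quantization op itself and register themselves for lookup by rounding type and quantization method):

    # Hypothetical skeleton; not one of the quantizers shipped in this release.
    class MyGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):

        def get_quant_config(self):
            # Report quantization parameters changed during retraining; keys must
            # match NodeQuantizationConfig attributes (e.g., updated thresholds).
            return {}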
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py

@@ -26,6 +26,19 @@ def ste_ceil(x: tf.Tensor) -> tf.Tensor:
     return error + x


+def safe_log(x: tf.Tensor, eps: float) -> tf.Tensor:
+    """
+    Computes the log of x, clamping the input at eps so the log does not fail on values close to zero.
+
+    Args:
+        x: input variable.
+        eps: limit value.
+
+    Returns: log of x where x > eps; otherwise, log of eps.
+    """
+    return tf.math.log(tf.maximum(x, eps))
+
+
 def ste_round(x: tf.Tensor) -> tf.Tensor:
     """
     Return the rounded values of a tensor.
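safe_log factors out the tf.math.log(tf.maximum(x, eps)) pattern that the removed Gumbel entropy code used inline; clamping the argument keeps the log (and its gradient) finite when a value collapses to zero. For example:

    import tensorflow as tf
    from model_compression_toolkit.gptq.keras.quantizer.quant_utils import safe_log

    p = tf.constant([0.5, 0.0])
    print(safe_log(p, eps=1e-8))  # [log(0.5), log(1e-8)] rather than [-0.693, -inf]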
@@ -59,20 +72,6 @@ def calculate_delta(max_tensor: tf.Tensor,
     return max_tensor / (2 ** (num_bits - int(signed)))


-def adjustable_steps(x: tf.Variable, t: float) -> tf.Tensor:
-    """
-    A function to gradually quantize a float variable to an integer of values [-1, 0, 1]
-    Args:
-        x: input float variable
-        t: temperature to control quantization
-
-    Returns:
-        semi-quantized variable
-
-    """
-    return tf.sigmoid(tf.add(x, 1) / t) + tf.sigmoid(tf.add(x, -1) / t) - 1
-
-
 def ste_clip(x: [tf.Tensor, tf.Variable], max_val=1, min_val=None) -> tf.Tensor:
     """
     clip a variable between fixed values such that min_val<=output<=max_val
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py (new file)

@@ -0,0 +1,78 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Dict, List, Tuple
+
+from model_compression_toolkit.gptq import GradientPTQConfigV2
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
+    get_inferable_quantizer_kwargs
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
+from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
+    get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import \
+    BaseKerasInferableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+    get_trainable_quantizer_weights_config
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+    get_trainable_quantizer_class
+
+
+def quantization_builder(n: common.BaseNode,
+                         gptq_config: GradientPTQConfigV2
+                         ) -> Tuple[Dict[str, BaseKerasGPTQTrainableQuantizer], List[BaseKerasInferableQuantizer]]:
+    """
+    Build quantizers for a node according to its quantization configuration and
+    a global NoOpQuantizeConfig object.
+
+    Args:
+        n: Node to build its QuantizeConfig.
+        gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
+
+    Returns:
+        A dictionary mapping the weights kernel attribute to a quantizer for GPTQ training,
+        and a list of inferable activation quantizers.
+        Note that a dictionary is returned although only a single attribute is mapped to a quantizer,
+        to stay compatible with the quantization infrastructure template.
+    """
+
+    weights_quantizers = {}
+    if n.is_weights_quantization_enabled():
+        quant_method = n.final_weights_quantization_cfg.weights_quantization_method
+
+        quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Weights,
+                                                        quantizer_type=gptq_config.rounding_type,
+                                                        quant_method=quant_method,
+                                                        quantizer_base_class=BaseKerasGPTQTrainableQuantizer)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=n.type,
+                                                              fw_info=DEFAULT_KERAS_INFO)
+
+        weights_quantizers.update({kernel_attribute: quantizer_class(get_trainable_quantizer_weights_config(n),
+                                                                     **gptq_config.gptq_quantizer_params_override)})
+
+    activation_quantizers = []
+    if n.is_activation_quantization_enabled():
+        quant_method = n.final_activation_quantization_cfg.activation_quantization_method
+
+        quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
+                                                        quant_method=quant_method,
+                                                        quantizer_base_class=BaseKerasInferableQuantizer)
+
+        kwargs = get_inferable_quantizer_kwargs(n, QuantizationTarget.Activation)
+
+        activation_quantizers.append(quantizer_class(**kwargs))
+
+    return weights_quantizers, activation_quantizers
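quantization_builder is what KerasGPTQTrainer.gptq_wrapper (earlier in this diff) calls per node: the kernel weight gets a trainable quantizer selected by (QuantizationTarget.Weights, gptq_config.rounding_type, quant_method), while activations get fixed inferable quantizers, since GPTQ fine-tunes weights only. Usage mirrors the wrapper hunk, assuming a graph node n, its Keras layer, and a GradientPTQConfigV2 instance:

    weights_quantizers, activation_quantizers = quantization_builder(n, gptq_config)
    wrapped_layer = KerasQuantizationWrapper(layer,
                                             weights_quantizers=weights_quantizers,
                                             activation_quantizers=activation_quantizers)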