mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -1,104 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import List, Tuple
17
-
18
- import tensorflow as tf
19
- from keras.models import Model
20
- from tensorflow.python.util.object_identity import Reference as TFReference
21
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
22
-
23
- from model_compression_toolkit.core import common
24
- from model_compression_toolkit.core.common import BaseNode
25
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
26
- from model_compression_toolkit.core.common.user_info import UserInformation
27
- from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder, \
28
- is_layer_fake_quant, get_node_name_from_layer
29
- from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
30
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
31
- from model_compression_toolkit.gptq.keras.quantizer.config_factory import quantization_config_builder_gptq
32
-
33
-
34
- class GPTQKerasModelBuilder(KerasModelBuilder):
35
- """
36
- Builder of GPTQ Keras models.
37
- """
38
-
39
- def __init__(self,
40
- graph: common.Graph,
41
- gptq_config: GradientPTQConfig,
42
- append2output=None,
43
- fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
44
- return_float_outputs: bool = True):
45
- """
46
-
47
- Args:
48
- graph: Graph to build the model from.
49
- gptq_config: Configuration for GPTQ optimization.
50
- append2output: Nodes to append to model's output.
51
- fw_info: Information about the specific framework of the model that is built.
52
- return_float_outputs: Whether the model returns float tensors or not.
53
- """
54
-
55
- super().__init__(graph,
56
- append2output,
57
- fw_info,
58
- return_float_outputs)
59
- self.gptq_config = gptq_config
60
-
61
- def _quantize_node_activations(self,
62
- node: BaseNode,
63
- input_tensors: List[TFReference]) -> List[TFReference]:
64
- """
65
- Quantize node's activation given input tensors.
66
-
67
- Args:
68
- node: Node to quantize its outputs.
69
- input_tensors: Input tensors of the node.
70
-
71
- Returns:
72
- Output of the node.
73
-
74
- """
75
-
76
- return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
77
-
78
- def build_model(self) -> Tuple[Model, UserInformation]:
79
- """
80
- Build a Keras GPTQ model and return it.
81
- Returns: GPTQ Keras model.
82
-
83
- """
84
- model, user_info = super().build_model()
85
-
86
- def _quantize(layer):
87
-
88
- node = self.oh.layer_to_node_dict.get(layer)
89
-
90
- if node is not None:
91
- return QuantizeWrapper(layer, quantization_config_builder_gptq(node, self.fw_info, self.gptq_config))
92
-
93
- elif is_layer_fake_quant(layer):
94
- return layer
95
-
96
- else:
97
- raise Exception(
98
- f"Mismatch between keras model and graph can't find node named: "
99
- f"{get_node_name_from_layer(layer)}")
100
-
101
- # clone each layer in the model and apply _quantize to the layer.
102
- model = tf.keras.models.clone_model(model, input_tensors=None, clone_function=_quantize)
103
-
104
- return model, user_info
@@ -1,119 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import List, Callable, Tuple
16
-
17
- import tensorflow as tf
18
- from model_compression_toolkit.gptq.keras.quantizer.configs.weight_quantizer_gptq_config import WeightQuantizeConfig
19
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
20
-
21
-
22
- class SAM:
23
- """
24
- This class implements Sharpness-Aware Minimization for Efficiently Improving Generalization (https://arxiv.org/abs/2010.01412)
25
- """
26
-
27
- def __init__(self, model2quantized,
28
- gradient_step: Callable,
29
- optimizer_with_param: List[Tuple[List, List[tf.Tensor]]],
30
- rho: float = 0.01,
31
- eps: float = 1e-12):
32
- """
33
- The init function of Sharpness-Aware Minimization gradient computation class.
34
- Args:
35
- model2quantized: Input quantized module
36
- gradient_step: A function that returns a list of gradients tensors
37
- optimizer_with_param: A list of optimizer classes to update with the corresponding parameters.
38
- rho: A floating point number that set the region of smoothness
39
- eps: A floating point number for numeric stability
40
- """
41
- assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
42
- self.rho = rho
43
- self.eps = eps
44
- self.gradient_step = gradient_step
45
-
46
- self.trainable_variables = [p for _, p in optimizer_with_param]
47
- self.m_var = [len(p) for p in self.trainable_variables]
48
- self.n_groups = len(self.trainable_variables)
49
- self.model2quantized = model2quantized
50
- self.e_ws = [[] for _ in range(len(optimizer_with_param))]
51
-
52
- def _enable_update_step_param(self):
53
- """
54
- This function enables the parameter update (update iteration index and gumbel random variable)
55
- Returns: None
56
-
57
- """
58
- for layer in self.model2quantized.layers:
59
- if isinstance(layer, QuantizeWrapper) and isinstance(
60
- layer.quantize_config, WeightQuantizeConfig):
61
- layer.quantize_config.enable_update()
62
-
63
- def _disable_update_step_param(self):
64
- """
65
- This function disables the parameter update (update iteration index and gumbel random variable)
66
- Returns: None
67
-
68
- """
69
- for layer in self.model2quantized.layers:
70
- if isinstance(layer, QuantizeWrapper) and isinstance(
71
- layer.quantize_config, WeightQuantizeConfig):
72
- layer.quantize_config.disable_update()
73
-
74
- def _update_w_location(self, gradients: List[List[tf.Tensor]]):
75
- """
76
- This function updates the weights position to the highest point
77
- Args:
78
- gradients: A list of gradients tensors
79
-
80
- Returns: None
81
-
82
- """
83
-
84
- for g in range(self.n_groups):
85
- self.e_ws[g].clear()
86
- grad_norm = tf.linalg.global_norm(gradients[g])
87
- ew_multiplier = self.rho / (grad_norm + self.eps)
88
- for i in range(self.m_var[g]):
89
- e_w = tf.math.multiply(gradients[g][i], ew_multiplier)
90
- self.trainable_variables[g][i].assign_add(e_w)
91
- self.e_ws[g].append(e_w)
92
-
93
- def _restore_w_location(self):
94
- """
95
- Restore weights to the original position
96
- Returns: None
97
-
98
- """
99
- for g in range(self.n_groups):
100
- for i in range(self.m_var[g]):
101
- self.trainable_variables[g][i].assign_add(-self.e_ws[g][i])
102
-
103
- def compute_gradients(self, *arg, **kwargs) -> (tf.Tensor, List[List[tf.Tensor]]):
104
- """
105
- This function compute the gradients of SAM optimizer
106
- Args:
107
- *arg: args to pass to the gradient step functions
108
- **kwargs: kwargs to pass to the gradient step functions
109
-
110
- Returns: A tensor of the loss value and a list of gradients tensors
111
-
112
- """
113
- self._enable_update_step_param()
114
- loss, grad = self.gradient_step(*arg, **kwargs)
115
- self._update_w_location(grad)
116
- self._disable_update_step_param()
117
- loss, grad = self.gradient_step(*arg, **kwargs)
118
- self._restore_w_location()
119
- return loss, grad
@@ -1,62 +0,0 @@
1
- # Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
18
- NoOpQuantizeConfig
19
- from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_registry import \
20
- QuantizeConfig
21
-
22
- from model_compression_toolkit.core import common
23
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
24
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
25
- from model_compression_toolkit.gptq.keras.quantizer import WeightQuantizeConfig
26
-
27
-
28
- def quantization_config_builder_gptq(n: common.BaseNode,
29
- fw_info: FrameworkInfo,
30
- gptq_config: GradientPTQConfig) -> QuantizeConfig:
31
- """
32
- Build a QuantizeConfig for a node according to its quantization configuration and
33
- a global NoOpQuantizeConfig object.
34
-
35
- Args:
36
- n: Node to build its QuantizeConfig.
37
- fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
38
- gptq_config: GPTQ Configuration class..
39
-
40
- Returns:
41
- A QuantizeConfig object with the appropriate quantizers (according to the node's
42
- quantization configuration).
43
- """
44
-
45
- if n.is_weights_quantization_enabled() and n.is_activation_quantization_enabled():
46
- qc = WeightQuantizeConfig(fw_info.get_kernel_op_attributes(n.type),
47
- n.final_weights_quantization_cfg,
48
- gptq_config)
49
- elif n.is_activation_quantization_enabled() and not n.is_weights_quantization_enabled():
50
- qc = NoOpQuantizeConfig() # Quantization is Preformed using fake quantization node
51
- elif n.is_weights_quantization_enabled() and not n.is_activation_quantization_enabled():
52
- qc = WeightQuantizeConfig(fw_info.get_kernel_op_attributes(n.type),
53
- n.final_weights_quantization_cfg,
54
- gptq_config)
55
-
56
- elif not n.is_weights_quantization_enabled() and not n.is_activation_quantization_enabled():
57
- qc = NoOpQuantizeConfig()
58
-
59
- else:
60
- raise Exception('Undefined quantization method')
61
-
62
- return qc
@@ -1,65 +0,0 @@
1
- # Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
17
- from typing import Tuple, List, Any, Dict
18
- from tensorflow import Tensor
19
- import six, abc
20
-
21
-
22
- @six.add_metaclass(abc.ABCMeta)
23
- class BaseQuantizeConfig(QuantizeConfig):
24
- """
25
- Base QuantizeConfig to define extra API methods needed by the GPTQ post-processing.
26
- """
27
-
28
- @abc.abstractmethod
29
- def get_quantization_variable(self):
30
- """
31
- A Functions that get the quantization parameters such as threshold, min, max ,etc.
32
-
33
- Returns: A list of trainable variable
34
-
35
- """
36
-
37
- @abc.abstractmethod
38
- def update_layer_quantization_params(self, layer) -> Tuple[Dict[str, Any],
39
- Dict[str, Any],
40
- Dict[str, Any]]:
41
- """
42
- A Function to calculate the needed change in attributes in NodeQuantizationConfig after retraining.
43
- Usually a function of the config quantizers.
44
-
45
- Args:
46
- layer: layer being quantized.
47
-
48
- Returns:
49
- 3 dictionaries of attributes the quantize_config retraining has changed during GPTQ retraining.
50
- Keys must match NodeQuantizationConfig attributes:
51
- 1. layer weights
52
- 2. weight quantization config attributes
53
- 3. activation quantization config attributes
54
-
55
- """
56
-
57
- @abc.abstractmethod
58
- def get_trainable_quantizer_parameters(self) -> List[Tensor]:
59
- """
60
- A function to get a list trainable of trainable parameters for GPTQ retraining from config quantizers
61
-
62
- Returns:
63
- A list of trainable Tensors
64
-
65
- """
@@ -1,269 +0,0 @@
1
- # Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import List, Tuple, Any, Dict
17
-
18
- from tensorflow import Tensor
19
- import tensorflow as tf
20
- from packaging import version
21
-
22
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
23
-
24
-
25
- if version.parse(tf.__version__) < version.parse("2.6"):
26
- from tensorflow.python.keras.layers import Layer
27
- else:
28
- from keras.engine.base_layer import Layer
29
-
30
- from tensorflow.python.training.tracking.data_structures import ListWrapper
31
- from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
32
-
33
- from model_compression_toolkit.gptq.keras.quantizer.configs.base_quantizer_gptq_config import BaseQuantizeConfig
34
- from model_compression_toolkit.core.keras.constants import KERNEL
35
-
36
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2, RoundingType
37
- from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.symmetric_gumbel import SymmetricGumbelRounding
38
- from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.uniform_gumbel import UniformGumbelRounding
39
- from model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste import STEWeightQuantizer
40
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
41
- from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MAX, RANGE_MIN
42
- from model_compression_toolkit.core import common
43
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
44
- from model_compression_toolkit.gptq.common import gptq_constants
45
-
46
-
47
- class WeightQuantizeConfig(BaseQuantizeConfig):
48
- """
49
- QuantizeConfig to quantize the weights of a layer using a TrainableQuantizer.
50
- """
51
-
52
- def __init__(self, weight_attrs: List[str],
53
- final_weights_quantization_cfg: NodeWeightsQuantizationConfig,
54
- gptq_config: GradientPTQConfigV2):
55
- """
56
- Initialize a TrainableQuantizer and set as the weights quantizer.
57
- Args:
58
- weight_attrs: Attributes of the layer's weights to quantize.
59
- final_weights_quantization_cfg: quantization config of the current layer.
60
- gptq_config: A GPTQ configuration calls.
61
- """
62
-
63
- num_bits = final_weights_quantization_cfg.weights_n_bits
64
- weight_channel_axis = final_weights_quantization_cfg.weights_channels_axis
65
- max_lsbs_change_map = gptq_config.lsb_change_per_bit_width
66
- self.weight_attrs = weight_attrs
67
- self.final_weights_quantization_cfg = final_weights_quantization_cfg
68
- self.gptq_config = gptq_config
69
-
70
- if final_weights_quantization_cfg.weights_quantization_method in [QuantizationMethod.SYMMETRIC,
71
- QuantizationMethod.POWER_OF_TWO]:
72
- is_power_of_two = QuantizationMethod.POWER_OF_TWO == final_weights_quantization_cfg.weights_quantization_method
73
- threshold_values = final_weights_quantization_cfg.weights_quantization_params.get(THRESHOLD)
74
- if gptq_config.rounding_type == RoundingType.STE:
75
- self.weight_quantizer = STEWeightQuantizer(num_bits=num_bits,
76
- per_axis=len(
77
- threshold_values.flatten()) > 1,
78
- threshold_values=threshold_values,
79
- signed=True,
80
- power_of_two=is_power_of_two,
81
- quantization_axis=weight_channel_axis,
82
- max_lsbs_change_map=max_lsbs_change_map)
83
- elif gptq_config.rounding_type == RoundingType.GumbelRounding:
84
- self.weight_quantizer = SymmetricGumbelRounding(num_bits=num_bits,
85
- per_axis=len(
86
- threshold_values.flatten()) > 1,
87
- threshold_values=threshold_values,
88
- signed=True,
89
- power_of_two=is_power_of_two,
90
- quantization_parameter_learning=gptq_config.quantization_parameters_learning,
91
- quantization_axis=weight_channel_axis,
92
- max_lsbs_change_map=max_lsbs_change_map,
93
- max_iteration=gptq_config.n_epochs,
94
- gumbel_config=gptq_config.quantizer_config)
95
- else:
96
- common.Logger.error(
97
- f"For quantization method {final_weights_quantization_cfg.weights_quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
98
- elif final_weights_quantization_cfg.weights_quantization_method == QuantizationMethod.UNIFORM:
99
- if not gptq_config.rounding_type == RoundingType.GumbelRounding:
100
- common.Logger.error(
101
- f"For quantization method {final_weights_quantization_cfg.weights_quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
102
- range_max = final_weights_quantization_cfg.weights_quantization_params.get(RANGE_MAX)
103
- range_min = final_weights_quantization_cfg.weights_quantization_params.get(RANGE_MIN)
104
- self.weight_quantizer = UniformGumbelRounding(num_bits=num_bits,
105
- per_axis=len(
106
- range_max.flatten()) > 1,
107
- min_range=range_min,
108
- max_range=range_max,
109
- signed=True,
110
- quantization_parameter_learning=gptq_config.quantization_parameters_learning,
111
- quantization_axis=weight_channel_axis,
112
- max_lsbs_change_map=max_lsbs_change_map,
113
- max_iteration=gptq_config.n_epochs,
114
- gumbel_config=gptq_config.quantizer_config)
115
-
116
- def enable_update(self):
117
- """
118
- This function enable the parameter update (update iteration index and gumbel random variable)
119
- Returns: None
120
-
121
- """
122
- if self.gptq_config.is_gumbel:
123
- return self.weight_quantizer.enable_update()
124
-
125
- def disable_update(self):
126
- """
127
-
128
- This function disable the parameter update (update iteration index and gumbel random variable)
129
- Returns: None
130
-
131
- """
132
- if self.gptq_config.is_gumbel:
133
- return self.weight_quantizer.disable_update()
134
-
135
- def get_weights_and_quantizers(self, layer: Layer) -> List[Tuple[Tensor, Quantizer]]:
136
- """
137
- Get a list of tuples with weights and the weight quantizer.
138
- The layer's attributes are used to get the weights.
139
- Args:
140
- layer: The layer the WeightQuantizeConfig wraps.
141
-
142
- Returns:
143
- List of tuples of the layer's weights and the weight quantizer.
144
- """
145
- return [(getattr(layer, weight_attr), self.weight_quantizer)
146
- for weight_attr in self.weight_attrs]
147
-
148
- def get_activations_and_quantizers(self, layer: Layer) -> list:
149
- return []
150
-
151
- def set_quantize_weights(self, layer: Layer, quantize_weights: List[Tensor]):
152
- """
153
- Set the layer weights with new passed weights.
154
- Args:
155
- layer: Layer to set its attributes.
156
- quantize_weights: Quantized weights to set as new weights.
157
-
158
- """
159
- if len(self.weight_attrs) != len(quantize_weights):
160
- raise ValueError(
161
- '`set_quantize_weights` called on layer {} with {} '
162
- 'weight parameters, but layer expects {} values.'.format(
163
- layer.name, len(quantize_weights), len(self.weight_attrs))) # pragma: no cover
164
-
165
- for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
166
- current_weight = getattr(layer, weight_attr)
167
- if current_weight.shape != weight.shape:
168
- raise ValueError('Existing layer weight shape {} is incompatible with'
169
- 'provided weight shape {}'.format(
170
- current_weight.shape, weight.shape)) # pragma: no cover
171
-
172
- setattr(layer, weight_attr, weight)
173
-
174
- def set_quantize_activations(self, layer, quantize_activations: ListWrapper):
175
- pass
176
-
177
- def get_output_quantizers(self, layer: Layer) -> list:
178
- return []
179
-
180
- @classmethod
181
- def from_config(cls, config: dict):
182
- """
183
- Instantiates a `WeightQuantizeConfig` from its config.
184
-
185
- Args:
186
- config: Output of `get_config()`.
187
-
188
- Returns:
189
- A `WeightQuantizeConfig` instance.
190
- """
191
-
192
- return cls(**config)
193
-
194
- def get_config(self) -> Dict[str, Any]:
195
- """
196
- Returns: The WeightQuantizeConfig configuration.
197
- """
198
- return {
199
- 'weight_attrs': self.weight_attrs,
200
- 'final_weights_quantization_cfg': self.final_weights_quantization_cfg,
201
- 'gptq_config': self.gptq_config,
202
- }
203
-
204
- def update_layer_quantization_params(self, layer: Layer) -> (Dict[str, tf.Tensor], Dict[str, Dict], Dict):
205
- """
206
- A Function to calculate the needed change in attributes in NodeQuantizationConfig after retraining.
207
- Usually a function of the config quantizers.
208
-
209
- Args:
210
- layer: layer being quantized.
211
-
212
- Returns:
213
- 3 dictionaries describing the change in layer's weights, weights config, activation config
214
- that changed during GPTQ retraining.
215
- Keys must match NodeQuantizationConfig attributes
216
-
217
- """
218
- weights = {}
219
- for weight, quantizer, quantizer_vars in layer._weight_vars:
220
- weights.update({KERNEL: quantizer(weight, training=False, weights=quantizer_vars)})
221
-
222
- quant_config = {gptq_constants.WEIGHTS_QUANTIZATION_PARAMS: self.weight_quantizer.get_quant_config(layer)}
223
-
224
- return weights, quant_config, {}
225
-
226
- def get_trainable_quantizer_parameters(self) -> List[tf.Tensor]:
227
- """
228
- A function to get a list trainable of trainable parameters for GPTQ retraining from config quantizers
229
-
230
- Returns:
231
- A list of trainable Tensors
232
-
233
- """
234
- return self.weight_quantizer.get_trainable_parameters()
235
-
236
- def get_aux_variable(self) -> List[tf.Tensor]:
237
- return [self.weight_quantizer.get_aux_variable()]
238
-
239
- def get_quantization_variable(self) -> List[tf.Tensor]:
240
- return self.weight_quantizer.get_quantization_variable()
241
-
242
- def get_temperature_variable(self) -> tf.Tensor:
243
- if self.gptq_config.is_gumbel:
244
- return self.weight_quantizer.get_temperature_variable()
245
- else:
246
- common.logger.Logger.error("Temperature variable only exist when using Gumbel Rounding Quantizer")
247
-
248
- def get_gumbel_probability(self) -> tf.Tensor:
249
- if self.gptq_config.is_gumbel:
250
- return self.weight_quantizer.get_gumbel_probability()
251
- else:
252
- common.logger.Logger.error("Probability variable only exist when using Gumbel Rounding Quantizer")
253
-
254
- def __eq__(self, other: Any) -> bool:
255
- """
256
- Check whether it equals to another object or not.
257
- """
258
- if not isinstance(other, WeightQuantizeConfig):
259
- return False
260
-
261
- return (self.weight_attrs == other.weight_attrs and
262
- self.weight_quantizer == other.weight_quantizer and
263
- self.gptq_config == other.gptq_config)
264
-
265
- def __ne__(self, other: Any) -> bool:
266
- """
267
- Check whether it differs from another object or not.
268
- """
269
- return not self.__eq__(other)