mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
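
The headline change in this release is the rename of the misspelled qunatizers_infrastructure package (entries 229-238, all deleted) to quantizers_infrastructure, now split into inferable_infrastructure and trainable_infrastructure subpackages. A minimal migration sketch, using only module paths that appear in the list above; the symbols inside these modules are not shown in this diff:

# 1.7.x (misspelled package, removed in 1.8.0):
# from model_compression_toolkit.qunatizers_infrastructure.keras import quantize_wrapper

# 1.8.0 (renamed and restructured):
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras import quantize_wrapper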
model_compression_toolkit/gptq/common/gptq_config.py

@@ -16,61 +16,46 @@ from enum import Enum
 from typing import Callable, Any, Dict
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
 from model_compression_toolkit.core import common
-
-MAX_LSBS_CHANGE_MAP = {8: 2,
-                       4: 1,
-                       2: 1}
-
-N_CYCLES = 4
-MIM_TEMP = 0.5
-MAX_TEMP = 1.0
-GAMMA_TEMPERATURE = 0.1
-GUMBEL_SCALE = 0.5
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR, REG_DEFAULT
 
 
 class RoundingType(Enum):
     """
     An enum for choosing the GPTQ rounding methods
     0. STRAIGHT-THROUGH ESTIMATOR
-    1. Gumbel Rounding
+    1. SoftQuantizer
     """
     STE = 0
-    GumbelRounding = 1
+    SoftQuantizer = 1
 
 
-class GumbelConfig(object):
+class GPTQHessianWeightsConfig:
     """
-    Configuration to use for quantization with Gumbel Rounding.
+    Configuration to use for computing the Hessian-based weights for GPTQ loss metric.
     """
 
     def __init__(self,
-                 temperature_learning: bool = True,
-                 n_cycles: int = N_CYCLES,
-                 minimal_temp: float = MIM_TEMP,
-                 maximal_temp: float = MAX_TEMP,
-                 gumbel_entropy_regularization: float = GAMMA_TEMPERATURE,
-                 gumbel_scale: float = GUMBEL_SCALE,
-                 gumbel_scale_per_bitwidth: Dict[int, float] = None):
-        """
-        Initialize a GumbelConfig.
-
+                 hessians_num_samples: int = 16,
+                 norm_weights: bool = True,
+                 log_norm: bool = True,
+                 scale_log_norm: bool = False,
+                 hessians_n_iter: int = 50):
 
+        """
+        Initialize a GPTQHessianWeightsConfig.
         Args:
-            temperature_learning (bool): Whether to update the temperature during the training or not.
-            gumbel_entropy_regularization (float): A floating point number that defines the gumbel entropy regularization factor.
-            n_cycles (int): A floating point number that defines the gumbel entropy regularization factor.
-            minimal_temp (float): A floating point number that defines the gumbel entropy regularization factor.
-            maximal_temp (float): A floating point number that defines the gumbel entropy regularization factor.
-            gumbel_scale (float): A normalization factor for the gumbel tensor values.
-            gumbel_scale_per_bitwidth (dict): An optional mapping between a bit-width and a gumbel scale value for Gumbel Rounding,
+            hessians_num_samples (int): Number of samples to use for computing the Hessian-based weights.
+            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            log_norm (bool): Whether to use log normalization to the GPTQ Hessian-based weights.
+            scale_log_norm (bool): Whether to scale the final vector of the Hessian weights.
+            hessians_n_iter (int): Number of random iterations to run Hessian approximation for GPTQ weights.
         """
-        self.gumbel_entropy_regularization = gumbel_entropy_regularization
-        self.temperature_learning = temperature_learning
-        self.n_cycles = n_cycles
-        self.minimal_temp = minimal_temp
-        self.maximal_temp = maximal_temp
-        self.gumbel_scale = gumbel_scale
-        self.gumbel_scale_per_bitwidth = gumbel_scale_per_bitwidth
+
+        self.hessians_num_samples = hessians_num_samples
+        self.norm_weights = norm_weights
+        self.log_norm = log_norm
+        self.scale_log_norm = scale_log_norm
+        self.hessians_n_iter = hessians_n_iter
 
 
 class GradientPTQConfig:
@@ -78,27 +63,19 @@ class GradientPTQConfig:
     """
     Configuration to use for quantization with GradientPTQ (experimental).
     """
 
-    def __init__(self,
-                 n_iter: int,
+    def __init__(self, n_iter: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-                 quantization_parameters_learning: bool = False,
-                 sam_optimization: bool = False,
-                 rounding_type: RoundingType = RoundingType.GumbelRounding,
-                 rho: float = 0.01,
-                 lsb_change_per_bit_width: dict = DefaultDict(MAX_LSBS_CHANGE_MAP, lambda: 1),
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
-                 quantizer_config: GumbelConfig = GumbelConfig(),
+                 rounding_type: RoundingType = RoundingType.SoftQuantizer,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-                 log_norm: bool = True,
-                 weights_n_iter: int = 50):
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfig.
 
@@ -111,20 +88,13 @@ class GradientPTQConfig:
                 accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-            quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
-            sam_optimization (bool): Whether to use sam optimization.
-            rounding_type (RoundingType): An enum that defines the rounding type (STE or GumbelRoudning).
-            rho (rho): A floating point number that defines the sam optimization lookahead.
-            lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
-            quantizer_config (Any): A class the contins the quantizer specific config.
+            rounding_type (RoundingType): An enum that defines the rounding type.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
-            optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
-            log_norm (bool): Whether to use log normalization to the GPTQ Jacobian-based weights.
-            weights_n_iter (int): Number of random iterations to run Jacobian approximation for GPTQ weights.
+            optimizer_bias (Any): Optimizer to override the rest optimizer for bias.
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
 
         """
         self.n_iter = n_iter
@@ -133,58 +103,35 @@ class GradientPTQConfig:
         self.loss = loss
         self.log_function = log_function
         self.train_bias = train_bias
-        self.quantization_parameters_learning = quantization_parameters_learning
+
         self.rounding_type = rounding_type
-        self.sam_optimization = sam_optimization
-        self.rho = rho
-        self.lsb_change_per_bit_width = lsb_change_per_bit_width
-        self.eps = eps
-        self.use_jac_based_weights = use_jac_based_weights
-        self.num_samples_for_loss = num_samples_for_loss
-        self.norm_weights = norm_weights
-        if not isinstance(quantizer_config, GumbelConfig) and self.is_gumbel:
-            common.Logger.error("Please use GumbelConfig as quantizer config when using Gumbel Rounding")
-        self.quantizer_config = quantizer_config
+        self.use_hessian_based_weights = use_hessian_based_weights
         self.optimizer_quantization_parameter = optimizer_quantization_parameter
         self.optimizer_bias = optimizer_bias
-        self.log_norm = log_norm
-        self.weights_n_iter = weights_n_iter
+        self.regularization_factor = regularization_factor
+        self.hessian_weights_config = hessian_weights_config
 
-    @property
-    def is_gumbel(self) -> bool:
-        """
-        This function state if Gumbel Rounding is in use.
-        Returns: boolean
-
-        """
-        return self.rounding_type == RoundingType.GumbelRounding
+        self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
+            else gptq_quantizer_params_override
 
 
 class GradientPTQConfigV2(GradientPTQConfig):
     """
     Configuration to use for quantization with GradientPTQV2 (experimental).
     """
-    def __init__(self,
-                 n_epochs: int,
+    def __init__(self, n_epochs: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-                 quantization_parameters_learning: bool = False,
-                 sam_optimization: bool = False,
-                 rounding_type: RoundingType = RoundingType.GumbelRounding,
-                 rho: float = 0.01,
-                 lsb_change_per_bit_width: dict = DefaultDict(MAX_LSBS_CHANGE_MAP, lambda: 1),
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
-                 quantizer_config: GumbelConfig = GumbelConfig(),
+                 rounding_type: RoundingType = RoundingType.SoftQuantizer,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-                 log_norm: bool = True,
-                 weights_n_iter: int = 50):
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfigV2.
 
@@ -197,20 +144,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
                 accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-            quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
-            sam_optimization (bool): Whether to use sam optimization.
-            rounding_type (RoundingType): An enum that defines the rounding type (STE or GumbelRoudning).
-            rho (rho): A floating point number that defines the sam optimization lookahead.
-            lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
-            quantizer_config (Any): A class the contins the quantizer specific config.
+            rounding_type (RoundingType): An enum that defines the rounding type.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
             optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
-            log_norm (bool): Whether to use log normalization to the GPTQ Jacobian-based weights.
-            weights_n_iter (int): Number of random iterations to run Jacobian approximation for GPTQ weights.
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
 
         """
 
@@ -220,20 +160,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
                          loss=loss,
                          log_function=log_function,
                          train_bias=train_bias,
-                         quantization_parameters_learning=quantization_parameters_learning,
-                         sam_optimization=sam_optimization,
                          rounding_type=rounding_type,
-                         rho=rho,
-                         lsb_change_per_bit_width=lsb_change_per_bit_width,
-                         eps=eps,
-                         use_jac_based_weights=use_jac_based_weights,
-                         num_samples_for_loss=num_samples_for_loss,
-                         norm_weights=norm_weights,
-                         quantizer_config=quantizer_config,
+                         use_hessian_based_weights=use_hessian_based_weights,
                          optimizer_quantization_parameter=optimizer_quantization_parameter,
                          optimizer_bias=optimizer_bias,
-                         log_norm=log_norm,
-                         weights_n_iter=weights_n_iter)
+                         regularization_factor=regularization_factor,
+                         hessian_weights_config=hessian_weights_config,
+                         gptq_quantizer_params_override=gptq_quantizer_params_override)
         self.n_epochs = n_epochs
 
     @classmethod
@@ -248,8 +181,5 @@ class GradientPTQConfigV2(GradientPTQConfig):
         """
         n_epochs = int(round(config_v1.n_iter) / n_ptq_iter)
         v1_params = config_v1.__dict__
-        v1_params.pop('n_iter')
+        v1_params = {k: v for k, v in v1_params.items() if k != 'n_iter'}
        return cls(n_epochs, **v1_params)
-
-
-
model_compression_toolkit/gptq/common/gptq_constants.py

@@ -1,11 +1,25 @@
+# Parameters names
 AUXVAR = 'auxvar_tensor'
 ITERVAR = 'iteration_variable'
-THRESHOLD_TENSOR = "ptq_threshold_tensor"
 SCALE_TENSOR = "scale_ptq_tensor"
-GPTQ_ITER = "_gptq_iter"
-AUXSHIFT = '_shift'
-TEMP = '_temp'
+AUXSHIFT = 'shift'
 WEIGHTS_QUANTIZATION_PARAMS = 'weights_quantization_params'
-PTQ_MIN_RANGE = "_min_range"
-PTQ_MAX_RANGE = "_max_range"
+PTQ_MIN_RANGE = "min_range"
+PTQ_MAX_RANGE = "max_range"
+PTQ_THRESHOLD = "ptq_threshold"
+SCALE_PTQ = "scale"
 
+# Default quantizer values
+N_CYCLES = 4
+MIM_TEMP = 0.5
+MAX_TEMP = 1.0
+REG_DEFAULT = 0.01
+MAX_LSB_CHANGE = 1
+
+# Soft rounding arguments values
+SOFT_ROUNDING_GAMMA = -0.1
+SOFT_ROUNDING_ZETA = 1.1
+
+# GPTQ config constant
+QUANT_PARAM_LEARNING_STR = 'quantization_parameter_learning'
+MAX_LSB_STR = 'max_lsbs_change_map'
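
The two "GPTQ config constant" strings are the dictionary keys consumed from gptq_quantizer_params_override (see the import added to gptq_config.py above, and the lookup in gptq_training.py below). A hedged sketch; the values are illustrative, with the LSB map mirroring the MAX_LSBS_CHANGE_MAP default deleted from gptq_config.py, and whether a given quantizer honors both keys is an assumption:

from model_compression_toolkit.gptq.common.gptq_constants import (
    QUANT_PARAM_LEARNING_STR, MAX_LSB_STR)

override = {
    QUANT_PARAM_LEARNING_STR: True,    # 'quantization_parameter_learning'
    MAX_LSB_STR: {8: 2, 4: 1, 2: 1},   # 'max_lsbs_change_map'
}
# Passed as GradientPTQConfigV2(..., gptq_quantizer_params_override=override).
# get_optimizer_with_param() in gptq_training.py reads QUANT_PARAM_LEARNING_STR
# from this dict to decide whether quantization parameters are trained.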
model_compression_toolkit/gptq/common/gptq_graph.py

@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 from typing import Tuple, List
+
+from model_compression_toolkit import FrameworkInfo
+from model_compression_toolkit.core.common import Logger
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 
@@ -42,3 +45,22 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
         compare_points_std.append(n.prior_info.std_output)
         compare_points_mean.append(n.prior_info.mean_output)
     return compare_points, compare_points_name, compare_points_mean, compare_points_std
+
+
+def get_kernel_attribute_name_for_gptq(layer_type: type, fw_info: FrameworkInfo) -> str:
+    """
+    Returns a layer's kernel attribute name for GPTQ training purposes.
+
+    Args:
+        layer_type: A type of model's layer.
+        fw_info: A FrameworkInfo object.
+
+    Returns: The name of the kernel attribute.
+
+    """
+    kernel_attribute = fw_info.get_kernel_op_attributes(layer_type)
+    if len(kernel_attribute) != 1:
+        Logger.error(  # pragma: no cover
+            f"In GPTQ training only the kernel weights attribute should be trained, but number of kernel "
+            f"attributes is {len(kernel_attribute)}.")
+    return kernel_attribute[0]
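
A usage sketch for the new helper; the Conv2D layer type and the DEFAULT_KERAS_INFO import path are assumptions not shown in this diff:

import tensorflow as tf

from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO  # assumed path
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq

# Expected to return the single kernel attribute name of the layer type
# (e.g. 'kernel' for Conv2D); a layer type with zero or several kernel
# attributes triggers Logger.error as shown above.
attr = get_kernel_attribute_name_for_gptq(tf.keras.layers.Conv2D, DEFAULT_KERAS_INFO)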
model_compression_toolkit/gptq/common/gptq_training.py

@@ -20,6 +20,7 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph, Logger, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
 from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
 
@@ -69,26 +70,23 @@ class GPTQTrainer(ABC):
     def get_optimizer_with_param(self,
                                  flattened_trainable_weights: List[Any],
                                  flattened_bias_weights: List[Any],
-                                 trainable_quantization_parameters: List[Any],
-                                 temperature_weights: List[Any]) -> List[Any]:
+                                 trainable_quantization_parameters: List[Any]) -> List[Any]:
         """
         Create Optimizers with their trainable parameters
         Args:
             flattened_trainable_weights: list of trainable weights parameters (flattened)
             flattened_bias_weights: list of trainable bias parameters (flattened)
             trainable_quantization_parameters: list of trainable quantization parameters
-            temperature_weights: list of temperature weights variables
         Returns:
             List of Optimizer objects with parameters
         """
 
         w2train = [*flattened_trainable_weights]
-        if self.gptq_config.is_gumbel:
-            if self.gptq_config.quantizer_config.temperature_learning:
-                w2train.extend(temperature_weights)
+
+        quant_params_learning = self.gptq_config.gptq_quantizer_params_override.get(QUANT_PARAM_LEARNING_STR, False)
 
         optimizer_with_param = [(self.gptq_config.optimizer, w2train)]
-        if self.gptq_config.train_bias or self.gptq_config.quantization_parameters_learning:
+        if self.gptq_config.train_bias or quant_params_learning:
             w2train_res = []
             if self.gptq_config.train_bias:
                 if self.gptq_config.optimizer_bias is not None:
@@ -96,35 +94,42 @@ class GPTQTrainer(ABC):
                 else:
                     w2train_res.extend(flattened_bias_weights)
                     if self.gptq_config.optimizer_rest is None:
-                        Logger.error(
+                        Logger.error(  # pragma: no cover
                            "To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
-            if self.gptq_config.quantization_parameters_learning:
+            if quant_params_learning:
                 if self.gptq_config.optimizer_quantization_parameter is not None:  # Ability to override optimizer
                     optimizer_with_param.append((self.gptq_config.optimizer_quantization_parameter,
                                                  trainable_quantization_parameters))
                 else:
                     w2train_res.extend(trainable_quantization_parameters)
                     if self.gptq_config.optimizer_rest is None:
-                        Logger.error(
-                            "To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
-            optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
+                        Logger.error(  # pragma: no cover
+                            "To enable quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
+            if len(w2train_res) > 0:
+                # Either bias or quantization parameters are trainable but did not provide a specific optimizer,
+                # so we should use optimizer_rest to train them
+                if self.gptq_config.optimizer_rest is None:
+                    Logger.error(  # pragma: no cover
+                        "To enable bias or quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
+                optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
 
         return optimizer_with_param
 
-    def compute_jacobian_based_weights(self,
-                                       representative_data_gen: Callable) -> np.ndarray:
+    def compute_hessian_based_weights(self,
+                                      representative_data_gen: Callable) -> np.ndarray:
         """
-        Computes the jacobian-based weights using the framework's model_grad method per batch of images.
+        Computes the Hessian-based weights using the framework's model_grad method per batch of images.
 
         Args:
-            representative_data_gen: Dataset used for inference to compute the jacobian-based weights.
+            representative_data_gen: Dataset used for inference to compute the Hessian-based weights.
 
         Returns: A vector of weights, one for each compare point,
             to be used for the loss metric weighted average computation when running GPTQ training.
         """
-        if self.gptq_config.use_jac_based_weights:
-            images = self._generate_images_batch(representative_data_gen, self.gptq_config.num_samples_for_loss)
+        if self.gptq_config.use_hessian_based_weights:
+            images = self._generate_images_batch(representative_data_gen,
+                                                 self.gptq_config.hessian_weights_config.hessians_num_samples)
 
             model_output_replacement = self._get_model_output_replacement()
 
@@ -142,17 +147,18 @@ class GPTQTrainer(ABC):
                     output_list=model_output_replacement,
                     all_outputs_indices=[],
                     alpha=0,
-                    norm_weights=self.gptq_config.norm_weights,
-                    n_iter=self.gptq_config.weights_n_iter)
+                    norm_weights=self.gptq_config.hessian_weights_config.norm_weights,
+                    n_iter=self.gptq_config.hessian_weights_config.hessians_n_iter)
                 points_apprx_jacobians_weights.append(image_ip_gradients)
-            if self.gptq_config.log_norm:
+            if self.gptq_config.hessian_weights_config.log_norm:
                 mean_jacobian_weights = np.mean(points_apprx_jacobians_weights, axis=0)
                 mean_jacobian_weights = np.where(mean_jacobian_weights != 0, mean_jacobian_weights,
                                                  np.partition(mean_jacobian_weights, 1)[1])
                 log_weights = np.log10(mean_jacobian_weights)
 
-                # To add scaling to the normalized weights replace return statement with the following line:
-                # return log_weights - np.min(log_weights) / (np.max(log_weights) - np.min(log_weights))
+                if self.gptq_config.hessian_weights_config.scale_log_norm:
+                    return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
+
                 return log_weights - np.min(log_weights)
             else:
                 return np.mean(points_apprx_jacobians_weights, axis=0)
@@ -204,7 +210,7 @@ class GPTQTrainer(ABC):
             Quantized graph for GPTQ fine-tuning, GPTQ graph user info
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s GPTQ model builder method.')
+                             f'framework\'s GPTQ model builder method.')  # pragma: no cover
 
     @abstractmethod
     def train(self, representative_data_gen: Callable):
@@ -214,7 +220,7 @@ class GPTQTrainer(ABC):
             representative_data_gen: Dataset to use for inputs of the models.
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s train method.')
+                             f'framework\'s train method.')  # pragma: no cover
 
     @abstractmethod
     def update_graph(self) -> Graph:
@@ -225,7 +231,7 @@ class GPTQTrainer(ABC):
             Updated graph after GPTQ.
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s update_graph method.')
+                             f'framework\'s update_graph method.')  # pragma: no cover
 
     def _get_model_output_replacement(self) -> List[BaseNode]:
         """
model_compression_toolkit/gptq/keras/gptq_loss.py

@@ -86,6 +86,7 @@ def mse_loss_per_tensor(y: tf.Tensor,
     _loss = tf.reduce_mean(tf.pow(tf.abs(y - x), p))
     return _loss / tf.reduce_mean(tf.pow(tf.abs(x), p)) if normalized else _loss
 
+
 def activation_mse(flp_act_list,
                    fxp_act_list,
                    p_vector=None,
@@ -116,7 +117,6 @@ def activation_mse(flp_act_list,
     return tf.reduce_mean(tf.stack(loss_values_list)), tf.reduce_mean(tf.stack(bias_loss_list))
 
 
-
 class GPTQMultipleTensorsLoss:
     def __init__(self, norm_loss: bool = False):
         self.alpha = None