mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
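Two structural changes stand out in the file list before the diffs themselves: the misspelled qunatizers_infrastructure package (items 229-238) is renamed to quantizers_infrastructure and split into inferable_infrastructure and trainable_infrastructure, and a new imx500_tpc target-platform model is added (items 56-62). Assuming the new TPC is exposed through the get_target_platform_capabilities facade touched in item 55, fetching it would look roughly like this sketch (the string identifiers are inferred from the new directory names, core/tpc_models/imx500_tpc/v1, and may differ):

    import model_compression_toolkit as mct

    # Hypothetical identifiers, mirroring the new directory layout:
    keras_tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500', 'v1')
    pytorch_tpc = mct.get_target_platform_capabilities('pytorch', 'imx500', 'v1')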
model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py (item 100, +68 -168):

@@ -13,45 +13,24 @@
  # limitations under the License.
  # ==============================================================================
 
- from typing import Dict, Any, List
+ from typing import Dict, Any
 
  import numpy as np
  import tensorflow as tf
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
- from tensorflow.python.framework.tensor_shape import TensorShape
- from model_compression_toolkit.core.keras.quantizer.base_quantizer import BaseTrainableQuantizer
- from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
+
+ from model_compression_toolkit.gptq import RoundingType
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
+ from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
  from model_compression_toolkit.core.common.constants import THRESHOLD
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
- from model_compression_toolkit.gptq.keras.quantizer.kernel_functions import get_kernel
- from model_compression_toolkit.gptq.common import gptq_constants
-
-
- def symmetric_quantizer(input_tensor: tf.Tensor,
-                         max_tensor: tf.Tensor,
-                         num_bits: int,
-                         signed: bool,
-                         power_of_two: bool = False) -> tf.Tensor:
-     """
-     Quantize a tensor symmetrically.
-     Args:
-         input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
-         max_tensor: Tensor with max values to compute the threshold.
-         num_bits: Num of bits to use.
-         signed: Signedness of the quantization range.
-         power_of_two: Whether the threshold should be constrained or not.
-
-     Returns:
-         A quantized tensor.
-     """
-
-     if power_of_two:
-         max_tensor = qutils.power_of_two_max(max_tensor)
-     delta = qutils.calculate_delta(max_tensor, num_bits, signed)
-     tensor_q = qutils.ste_round(input_tensor / delta)
-     min_int = -int(signed) * (2 ** (num_bits - int(signed)))
-     max_int = (2 ** (num_bits - int(signed))) - 1
-     return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)
+ from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+     get_threshold_reshape_shape
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
  def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
@@ -63,6 +42,7 @@ def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
                                      max_lsbs_change: int = 1) -> tf.Tensor:
      """
      Quantize a tensor symmetrically with maximum LSBs shift.
+
      Args:
          input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
          auxvar_tensor: Tensor that manifests the bit shift the weight due to gptq
@@ -87,195 +67,115 @@ def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
      return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)
 
 
- class STEWeightQuantizer(BaseTrainableQuantizer):
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                 quantizer_type=RoundingType.STE)
+ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
      """
-     Trainable constrained quantizer to quantize a layer inputs.
+     Trainable symmetric quantizer to quantize a layer weights.
      """
 
      def __init__(self,
-                  num_bits: int,
-                  per_axis: bool,
-                  signed: bool,
-                  threshold_values: np.ndarray,
-                  quantization_axis: int = -1,
-                  power_of_two: bool = True,
+                  quantization_config: TrainableQuantizerWeightsConfig,
                   max_lsbs_change_map: dict = DefaultDict({}, lambda: 1)):
          """
-         Initialize a TrainableWeightQuantizer object with parameters to use
-         for the quantization.
+         Initialize a STEWeightGPTQQuantizer object with parameters to use for the quantization.
 
          Args:
-             num_bits: Number of bits to use for the quantization.
-             per_axis: Whether to quantize per-channel or per-tensor.
-             signed: Signedness to use for the quantization range.
-             threshold_values: Threshold to use for the quantization.
-             quantization_axis: Axis of tensor to use for the quantization.
-             power_of_two: Whether the threshold should be constrained or not.
+             quantization_config: Trainable weights quantizer config.
              max_lsbs_change_map: a mapping between number of bits to max lsb change.
          """
-         self.num_bits = num_bits
-         self.per_axis = per_axis
-         self.signed = signed
+         super().__init__(quantization_config)
+         self.num_bits = quantization_config.weights_n_bits
+         self.per_channel = quantization_config.weights_per_channel_threshold
+
+         threshold_values = quantization_config.weights_quantization_params[THRESHOLD]
          self.threshold_shape = np.asarray(threshold_values).shape
-         self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_axis else float(
+         self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_channel else float(
              threshold_values)
-         self.quantization_axis = quantization_axis
-         self.power_of_two = power_of_two
-         self.max_lsbs_change = max_lsbs_change_map.get(num_bits)
-         self.quantizer_parameters = {}
 
-     def build(self,
-               tensor_shape: TensorShape,
-               name: str,
-               layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
+         self.quantization_axis = quantization_config.weights_channels_axis
+         self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+         self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
+
+     def initialize_quantization(self,
+                                 tensor_shape: Any,
+                                 name: str,
+                                 layer: Any):
          """
-         Add min and max variables to layer.
-         Args:
-             tensor_shape: Tensor shape the quantizer quantize.
-             name: Prefix of variables names.
-             layer: Layer to add the variables to. The variables are saved
-             in the layer's scope.
+         Add quantizer parameters to the quantizer parameters dictionary
 
-         Returns:
-             Dictionary of new variables.
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
          """
-         w_shape = get_kernel(layer.weights).shape
-         ar_iter = layer.add_weight(
-             name + gptq_constants.GPTQ_ITER,
-             shape=(),
-             initializer=tf.keras.initializers.Constant(0.0),
-             trainable=False)
 
          ptq_threshold_tensor = layer.add_weight(
-             name + gptq_constants.THRESHOLD_TENSOR,
-             shape=len(self.threshold_values) if self.per_axis else (),
+             f"{name}_{PTQ_THRESHOLD}",
+             shape=len(self.threshold_values) if self.per_channel else (),
              initializer=tf.keras.initializers.Constant(1.0),
              trainable=False)
          ptq_threshold_tensor.assign(self.threshold_values)
 
+         w = getattr(layer.layer, name)
          auxvar_tensor = layer.add_weight(
-             name + gptq_constants.AUXVAR,
-             shape=w_shape,
+             f"{name}_{AUXVAR}",
+             shape=list(w.shape),
              initializer=tf.keras.initializers.Constant(0.0),
              trainable=True)
 
          # save the quantizer added parameters for later calculations
-         self.quantizer_parameters = {gptq_constants.THRESHOLD_TENSOR: ptq_threshold_tensor,
-                                      gptq_constants.AUXVAR: auxvar_tensor,
-                                      gptq_constants.GPTQ_ITER: ar_iter}
-         return self.quantizer_parameters
+         self.add_quantizer_variable(PTQ_THRESHOLD, ptq_threshold_tensor, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
 
-     def __call__(self, inputs: tf.Tensor,
-                  training: bool,
-                  weights: Dict[str, tf.Variable],
-                  **kwargs: Dict[str, Any]):
+     def __call__(self,
+                  inputs: tf.Tensor,
+                  training: bool):
          """
          Quantize a tensor.
+
          Args:
              inputs: Input tensor to quantize.
              training: Whether the graph is in training mode.
-             weights: Dictionary of weights the quantizer can use to quantize the tensor.
-             **kwargs: Additional variables the quantizer may receive.
 
          Returns:
              The quantized tensor.
          """
 
-         auxvar = weights[gptq_constants.AUXVAR]
-         ptq_threshold_tensor = weights[gptq_constants.THRESHOLD_TENSOR]
+         auxvar = self.get_quantizer_variable(AUXVAR)
+         ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
 
-         if self.per_axis:
-             input_shape = inputs.shape
-             n_axis = len(input_shape)
-             quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
-                 self.quantization_axis
-             reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]
+         if self.per_channel:
+             reshape_shape = get_threshold_reshape_shape(inputs.shape,
+                                                         quant_axis=self.quantization_axis,
+                                                         quant_axis_dim=-1)
              ptq_threshold_tensor = tf.reshape(ptq_threshold_tensor, reshape_shape)
-             q_tensor = pertubation_symmetric_quantizer(inputs, auxvar,
+             q_tensor = pertubation_symmetric_quantizer(inputs,
+                                                        auxvar,
                                                         ptq_threshold_tensor,
                                                         self.num_bits,
-                                                        self.signed,
-                                                        self.power_of_two,
+                                                        signed=True,
+                                                        power_of_two=self.power_of_two,
                                                         max_lsbs_change=self.max_lsbs_change)
              return q_tensor
          else:
-             return pertubation_symmetric_quantizer(inputs, auxvar,
+             return pertubation_symmetric_quantizer(inputs,
+                                                    auxvar,
                                                     ptq_threshold_tensor,
                                                     self.num_bits,
-                                                    self.signed,
-                                                    self.power_of_two)
-
-     def get_aux_variable(self) -> tf.Tensor:
-         return self.quantizer_parameters[gptq_constants.AUXVAR]
-
-     def get_config(self) -> Dict[str, Any]:
-         """
-         Returns: Configuration of TrainableQuantizer.
-         """
+                                                    signed=True,
+                                                    power_of_two=self.power_of_two)
 
-         return {
-             'num_bits': self.num_bits,
-             'per_axis': self.per_axis,
-             'symmetric': self.symmetric,
-             'power_of_two': self.power_of_two
-         }
 
-     def get_quant_config(self, layer) -> Dict[str, np.ndarray]:
+     def get_quant_config(self) -> Dict[str, np.ndarray]:
          """
          Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
 
-         Args:
-             layer: quantized layer
-
          Returns:
              A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
              Keys must match NodeQuantizationConfig attributes
 
          """
-         old_threshold = self.quantizer_parameters[gptq_constants.THRESHOLD_TENSOR]
+         old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
          return {THRESHOLD: old_threshold.numpy().reshape(self.threshold_shape)}
-
-     def get_trainable_parameters(self):
-         """
-         A function to get a list trainable of trainable parameters of the quantizer for GPTQ retraining
-
-         Returns:
-             A list of trainable Tensors
-
-         """
-         return [t for t in self.quantizer_parameters.values() if t.trainable]
-
-     def get_quantization_variable(self) -> List[tf.Tensor]:
-         """
-         This function return a list of quantizer parameters.
-         Returns: A list of the quantizer parameters
-
-         """
-         return [self.quantizer_parameters[gptq_constants.THRESHOLD_TENSOR]]
-
-     def __eq__(self, other: Any) -> bool:
-         """
-         Check if equals to another object.
-         Args:
-             other: Other object to compare.
-
-         Returns:
-             Whether they are equal or not.
-         """
-         if not isinstance(other, STEWeightQuantizer):
-             return False
-
-         return (self.num_bits == other.num_bits and
-                 self.per_axis == other.per_axis and
-                 self.symmetric == other.symmetric)
-
-     def __ne__(self, other: Any) -> bool:
-         """
-         Check if not equals to another object.
-         Args:
-             other: Other object to compare.
-
-         Returns:
-             Whether they are differ or not.
-         """
-         return not self.__eq__(other)
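The net effect of this file's diff: the standalone symmetric_quantizer helper is removed, and STEWeightQuantizer becomes STEWeightGPTQQuantizer, registered through the @mark_quantizer decorator and configured by a single TrainableQuantizerWeightsConfig instead of loose constructor arguments. For reference, the removed helper performed plain symmetric rounding; a minimal NumPy sketch of the same forward arithmetic (without the straight-through-estimator gradient behavior of qutils.ste_round and qutils.ste_clip, and assuming calculate_delta follows the usual threshold / 2**(num_bits - signed) convention):

    import numpy as np

    def symmetric_quantize(x, threshold, num_bits, signed=True, power_of_two=False):
        if power_of_two:
            # Constrain the threshold to the nearest power of two from above.
            threshold = 2.0 ** np.ceil(np.log2(threshold))
        # Step size over the signed range [-threshold, threshold).
        delta = threshold / (2 ** (num_bits - int(signed)))
        q = np.round(x / delta)
        min_int = -int(signed) * (2 ** (num_bits - int(signed)))
        max_int = (2 ** (num_bits - int(signed))) - 1
        return delta * np.clip(q, min_int, max_int)

    # 8-bit signed example with a power-of-two threshold:
    symmetric_quantize(np.array([0.3, -1.7, 0.9]), threshold=1.5, num_bits=8,
                       power_of_two=True)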
model_compression_toolkit/gptq/pytorch/gptq_training.py (item 101, +78 -39):

@@ -12,24 +12,29 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Callable, List, Tuple
+ from typing import Callable, List, Tuple, Union
 
  import numpy as np
+ from torch.nn import Module
  from tqdm import tqdm
  import copy
  import torch
  from model_compression_toolkit.core.common.logger import Logger
+ from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+ from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
  from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
- from model_compression_toolkit.core.common import Graph
+ from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.core.pytorch.constants import BIAS, KERNEL
- from model_compression_toolkit.gptq.pytorch.gptq_model_builder import GPTQPytorchModelBuilder
+ from model_compression_toolkit.core.pytorch.constants import BIAS
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model, torch_tensor_to_numpy
- from model_compression_toolkit.gptq.pytorch.gptq_graph_info import get_trainable_parameters, get_weights_for_loss
- from model_compression_toolkit.gptq.pytorch.quantizer.quantizer_wrapper import WeightQuantizerWrapper
- from model_compression_toolkit.gptq.pytorch.gptq_graph_info import get_gumbel_probability
+ from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
+     get_weights_for_loss
+ from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
 
 
  class PytorchGPTQTrainer(GPTQTrainer):
@@ -66,11 +71,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
          else:
              self.input_scale = self.gptq_user_info.input_scale
 
-         trainable_weights, trainable_bias, trainable_threshold, trainable_temperature = get_trainable_parameters(
+         trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
              self.fxp_model,
-             add_bias=self.gptq_config.train_bias,
-             quantization_parameters_learning=self.gptq_config.quantization_parameters_learning,
-             is_gumbel=self.gptq_config.is_gumbel)
+             add_bias=self.gptq_config.train_bias)
 
          self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
          if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
@@ -81,10 +84,45 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
          self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
                                                                    trainable_bias,
-                                                                   trainable_threshold,
-                                                                   trainable_temperature)
+                                                                   trainable_threshold)
 
-         self.weights_for_average_loss = to_torch_tensor(self.compute_jacobian_based_weights(representative_data_gen))
+         self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights(representative_data_gen))
+
+         self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
+
+     def _is_gptq_applicable(self,
+                             node: BaseNode) -> bool:
+         """
+         A function for deciding if a layer should be fine-tuned during GPTQ.
+         Args:
+             node (BaseNode): Node for quantization decision
+         Returns:
+             A boolean whether the layer is to be wrapped with a Quantization Wrapper.
+         """
+
+         if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
+             Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
+                          f"without a kernel isn't supported.")
+         return node.is_weights_quantization_enabled()
+
+     def gptq_wrapper(self, n: BaseNode, layer: Module) -> Union[qi.PytorchQuantizationWrapper, Module]:
+         """
+         A function which takes a computational graph node and a pytorch layer and perform the quantization wrapping.
+
+         Args:
+             n: A node of mct graph.
+             layer: A pytorch layer
+
+         Returns: Wrapped layer if the layer should be wrap, otherwise returns the layer as is.
+         """
+
+         if self._is_gptq_applicable(n):
+             weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config)
+             return qi.PytorchQuantizationWrapper(layer,
+                                                  weights_quantizers=weights_quantizers,
+                                                  activation_quantizers=activation_quantizers)
+         else:
+             return layer
 
      def build_gptq_model(self):
          """
@@ -92,10 +130,13 @@ class PytorchGPTQTrainer(GPTQTrainer):
          Returns:
              Quantized graph for GPTQ fine-tuning, GPTQ graph user info
          """
-         return GPTQPytorchModelBuilder(self.graph_quant,
-                                        self.gptq_config,
-                                        append2output=self.compare_points,
-                                        return_float_outputs=True).build_model()
+         gptq_model, gptq_user_info = PyTorchModelBuilder(graph=self.graph_quant,
+                                                          append2output=self.compare_points,
+                                                          fw_info=self.fw_info,
+                                                          wrapper=self.gptq_wrapper,
+                                                          return_float_outputs=True).build_model()
+
+         return gptq_model, gptq_user_info
 
      def train(self, representative_data_gen: Callable):
          """
@@ -145,14 +186,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
                                           self.compare_points_std,
                                           self.weights_for_average_loss)
 
-             if self.gptq_config.is_gumbel and self.gptq_config.quantizer_config.temperature_learning:
-                 gumbel_prob = get_gumbel_probability(self.fxp_model)
-                 gumbel_reg = 0
-                 for p in gumbel_prob:
-                     entropy = -torch.mean(torch.sum(p * torch.log(torch.maximum(p, self.gptq_config.eps*torch.ones_like(p))), dim=0))
-                     gumbel_reg += entropy
-                 gumbel_reg = 0 if gumbel_reg == 0 else gumbel_reg/len(gumbel_prob)
-                 loss_value += self.gptq_config.quantizer_config.gumbel_entropy_regularization * gumbel_reg
+             reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+
+             loss_value += reg_value
 
              # Back-pass
              loss_value.backward()
@@ -202,20 +238,23 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
          # Update graph after training
          for name, layer in self.fxp_model.named_modules():
-             if isinstance(layer, WeightQuantizerWrapper):
+             if isinstance(layer, PytorchQuantizationWrapper):
                  node = self.graph_quant.find_node_by_name(name)
                  if len(node) != 1:
                      Logger.error(f"Can't update GPTQ graph due to missing layer named: {name}")
                  node = node[0]
-                 # Weight
-                 node.set_weights_by_keys(KERNEL, self.fw_impl.to_numpy(layer.weight_quantizer(layer.float_weight, training=False)))
-                 # Weight quantization params
-                 if self.gptq_config.quantization_parameters_learning:
-                     node.final_weights_quantization_cfg.set_weights_quantization_param(layer.weight_quantizer.get_weight_quant_params())
-                 # Bias
-                 if self.gptq_config.train_bias and hasattr(layer.op, BIAS):
-                     node.set_weights_by_keys(BIAS, self.fw_impl.to_numpy(getattr(layer.op, BIAS)))
-
+                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                                       fw_info=self.fw_info)
+                 weights, weight_quant_config, activation_quant_config = \
+                     layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
+                 for weight_attr, weight in weights.items():
+                     node.set_weights_by_keys(weight_attr, self.fw_impl.to_numpy(weight))
+                 for config_attr, config_value in weight_quant_config.items():
+                     node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
+                 for config_attr, config_value in activation_quant_config.items():
+                     node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
+                 if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
+                     node.set_weights_by_keys(BIAS, self.fw_impl.to_numpy(getattr(layer.layer, BIAS)))
 
          return graph_quant
@@ -229,7 +268,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
          # Fxp model: unfreeze bias trainable parameters
          for layer in self.fxp_model.modules():
-             if isinstance(layer, WeightQuantizerWrapper):
-                 if hasattr(layer.op, BIAS):
-                     bias = getattr(layer.op, BIAS)
+             if isinstance(layer, PytorchQuantizationWrapper):
+                 if hasattr(layer.layer, BIAS):
+                     bias = getattr(layer.layer, BIAS)
                      bias.requires_grad = self.gptq_config.train_bias
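The trainer no longer owns a dedicated GPTQPytorchModelBuilder; it passes gptq_wrapper as a callback to the generic PyTorchModelBuilder, and all per-layer state lives in PytorchQuantizationWrapper. A minimal sketch of the same wrapping pattern applied by hand to a single layer (my_weight_quantizer is a hypothetical stand-in for whatever quantization_builder returns for the node):

    import torch.nn as nn
    from model_compression_toolkit import quantizers_infrastructure as qi

    conv = nn.Conv2d(3, 16, kernel_size=3)
    # 'weight' plays the role that get_kernel_attribute_name_for_gptq resolves
    # for Conv2d; my_weight_quantizer is a placeholder for a trainable
    # quantizer instance, and an empty list leaves activations untouched.
    wrapped = qi.PytorchQuantizationWrapper(conv,
                                            weights_quantizers={'weight': my_weight_quantizer},
                                            activation_quantizers=[])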
model_compression_toolkit/gptq/pytorch/graph_info.py (item 102, new file, +81 -0):

@@ -0,0 +1,81 @@
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+ import torch.nn as nn
+ from typing import List
+ from model_compression_toolkit.core.pytorch.constants import BIAS
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+ from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+
+
+ def get_gptq_trainable_parameters(fxp_model: nn.Module,
+                                   add_bias: bool = False,
+                                   ) -> (List[nn.Parameter], List[nn.Parameter], List[nn.Parameter]):
+     """
+     Get trainable parameters from all layers in a model
+
+     Args:
+         fxp_model: Model to get its trainable parameters.
+         add_bias: Whether to include biases of the model (if there are) or not.
+
+     Returns:
+         A list of trainable variables in a model. Each item is a list of a layers weights.
+     """
+
+     trainable_aux_weights = nn.ParameterList()
+     trainable_threshold = nn.ParameterList()
+     trainable_bias = nn.ParameterList()
+
+     for layer in fxp_model.modules():
+         if isinstance(layer, PytorchQuantizationWrapper):
+             kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                   fw_info=DEFAULT_PYTORCH_INFO)
+
+             # collect trainable weights per quantizer
+             quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
+             quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
+             trainable_aux_weights.extend(quantizer_trainable_weights)
+             trainable_threshold.extend(quantizer_trainable_threshold)
+
+             if add_bias and hasattr(layer.layer, BIAS):
+                 bias = getattr(layer.layer, BIAS)
+                 trainable_bias.append(bias)
+
+     return trainable_aux_weights, trainable_bias, trainable_threshold
+
+
+ def get_weights_for_loss(fxp_model: nn.Module) -> [List[nn.Parameter], List[torch.Tensor]]:
+     """
+     Get all float and quantized kernels for the GPTQ loss
+
+     Args:
+         fxp_model: Model to get its float and quantized weights.
+
+     Returns:
+         A list of float kernels, each item is the float kernel of the layer
+         A list of quantized kernels, each item is the quantized kernel of the layer
+     """
+
+     flp_weights_list, fxp_weights_list = [], []
+     for layer in fxp_model.modules():
+         if isinstance(layer, PytorchQuantizationWrapper):
+             # Collect pairs of float and quantized weights per layer
+             for weight, quantizer_vars, quantizer in layer.get_weights_vars():
+                 flp_weights_list.append(quantizer_vars)
+                 fxp_weights_list.append(quantizer(training=False, inputs=quantizer_vars))
+
+     return flp_weights_list, fxp_weights_list
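A sketch of how the two new helpers are consumed, mirroring the calls in gptq_training.py above; it assumes fxp_model is a model whose kernel layers were already wrapped with PytorchQuantizationWrapper:

    import torch
    from model_compression_toolkit.gptq.pytorch.graph_info import (
        get_gptq_trainable_parameters, get_weights_for_loss)

    aux_weights, biases, thresholds = get_gptq_trainable_parameters(fxp_model,
                                                                    add_bias=True)
    # One possible parameter-group layout; the trainer builds its own groups
    # internally via get_optimizer_with_param.
    optimizer = torch.optim.Adam([{'params': aux_weights},
                                  {'params': biases, 'lr': 1e-4}])

    # Float/quantized kernel pairs feeding the GPTQ reconstruction loss.
    flp_weights, fxp_weights = get_weights_for_loss(fxp_model)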
model_compression_toolkit/gptq/pytorch/quantization_facade.py (item 103, +12 -18):

@@ -21,6 +21,7 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV
  from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
+ from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
  from model_compression_toolkit.gptq.runner import gptq_runner
  from model_compression_toolkit.core.exporter import export_model
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
@@ -38,7 +39,7 @@ if FOUND_TORCH:
      from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
      from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
      from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
-     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_fully_quantized_pytorch_model
+     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
      import torch
      from torch.nn import Module
      from torch.optim import Adam, Optimizer
@@ -71,26 +72,19 @@ if FOUND_TORCH:
          Import MCT and Create a GradientPTQConfigV2 to run for 5 epochs:
 
          >>> import model_compression_toolkit as mct
-         >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=5)
+         >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
 
          Other PyTorch optimizers can be passed with dummy params:
 
          >>> import torch
-         >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
+         >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
 
          The configuration can be passed to :func:`~model_compression_toolkit.pytorch_post_training_quantization` in order to quantize a pytorch model using gptq.
 
          """
-         bias_optimizer = Adam([torch.Tensor([])], lr=LR_BIAS_DEFAULT)
-         optimizer_quantization_parameter = Adam([torch.Tensor([])], lr=LR_QUANTIZATION_PARAM_DEFAULT)
-         return GradientPTQConfigV2(n_epochs,
-                                    optimizer,
-                                    optimizer_rest=optimizer_rest,
-                                    loss=loss,
-                                    log_function=log_function,
-                                    train_bias=True,
-                                    optimizer_quantization_parameter=optimizer_quantization_parameter,
-                                    optimizer_bias=bias_optimizer)
+         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
+         return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+                                    log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)
 
 
      def pytorch_gradient_post_training_quantization_experimental(model: Module,
@@ -152,15 +146,15 @@ if FOUND_TORCH:
 
          Pass the module, the representative dataset generator and the configuration (optional) to get a quantized module
 
-         >>> quantized_module, quantization_info = mct.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
+         >>> quantized_module, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
 
          """
 
          if core_config.mixed_precision_enable:
              if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
                  common.Logger.error("Given quantization config to mixed-precision facade is not of type "
-                                     "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization API,"
-                                     "or pass a valid mixed precision configuration.")
+                                     "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                                     "API, or pass a valid mixed precision configuration.")  # pragma: no cover
 
              common.Logger.info("Using experimental mixed-precision quantization. "
                                 "If you encounter an issue please file a bug.")
@@ -212,10 +206,10 @@ else:
      def get_pytorch_gptq_config(*args, **kwargs):
          Logger.critical('Installing Pytorch is mandatory '
                          'when using pytorch_gradient_post_training_quantization_experimental. '
-                         'Could not find torch package.')
+                         'Could not find torch package.')  # pragma: no cover
 
 
      def pytorch_gradient_post_training_quantization_experimental(*args, **kwargs):
          Logger.critical('Installing Pytorch is mandatory '
                          'when using pytorch_gradient_post_training_quantization_experimental. '
-                         'Could not find the torch package.')
+                         'Could not find the torch package.')  # pragma: no cover
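Taken together, the facade changes move the GPTQ entry points under the mct.gptq namespace. An end-to-end sketch assembled from the doctest lines above (module is a user-supplied torch.nn.Module; the input shape in the toy data generator is an assumption for an RGB 224x224 model):

    import numpy as np
    import model_compression_toolkit as mct

    def repr_datagen():
        # Toy representative dataset: one random batch per call.
        return [np.random.randn(1, 3, 224, 224).astype(np.float32)]

    gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
    quantized_module, quantization_info = \
        mct.gptq.pytorch_gradient_post_training_quantization_experimental(
            module, repr_datagen, gptq_config=gptq_conf)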