mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py (deleted)
@@ -1,157 +0,0 @@
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- from typing import Union, List
- from abc import abstractmethod
- import torch
- import numpy as np
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
- from model_compression_toolkit.core.common.logger import Logger
- from model_compression_toolkit.gptq.pytorch.quantizer.gptq_quantizer import BaseWeightQuantizer
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
- from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import sample_gumbel
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
- from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_clip
-
- P_INIT = 0.01
- GR_SHIFT_BASE = 2
-
-
- def init_aux_var(ceil_indicator: torch.Tensor, w_shape: torch.Size, m: int, p: float = P_INIT) -> torch.Tensor:
-     """
-     This function generates a random pi matrix for Gumbel Rounding.
-     Args:
-         ceil_indicator: An array of indicators of whether each value should be rounded up (ceil) or down (floor).
-         w_shape (torch.Size): The shape of the weights tensor to be quantized.
-         m (int): The number of shift options.
-         p (float): The probability of the non-round options of the pi matrix.
-
-     Returns: A torch tensor holding the pi matrix.
-
-     """
-
-     if m < 2:
-         Logger.error("m must be at least two")
-     if m % 2 != 0:
-         Logger.error("m must be divisible by two")
-     m_hat = m // 2 - 1
-     shift = -np.log(-np.log(1 - p))
-     n = np.random.randn(*[m, *w_shape]) * np.sqrt(np.power(np.pi, 2) / 6)
-     n = n.reshape([m, -1]).T
-     ceil_indicator = ceil_indicator.cpu().numpy().flatten()
-     n[np.arange(ceil_indicator.size), ceil_indicator + m_hat] += shift
-     n = n.T.reshape(*[m, *w_shape])
-     return torch.from_numpy(n).float()
-
-
- def init_shift_var(m: int) -> torch.Tensor:
-     """
-     This function generates a tensor of m shift values, from -(m // 2) + 1 to m // 2.
-     Args:
-         m: The number of shift options.
-
-     Returns: A tensor of size m.
-
-     """
-     m_hat = m // 2
-     aux_index_shift = [-m_hat + i + 1 for i in range(m)]
-     return torch.Tensor(aux_index_shift)
-
-
- class BaseGumbelWeightQuantizer(BaseWeightQuantizer):
-     """
-     Base class that implements a quantizer with trainable parameters to be used for GPTQ training.
-     """
-
-     def __init__(self,
-                  weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                  gptq_config: GradientPTQConfigV2,
-                  weight_shape: torch.Size):
-         """
-         Construct a Pytorch module that applies a fake weight quantizer with Gumbel rounding.
-         Args:
-             weights_quantization_cfg: Configuration of weight quantization.
-             gptq_config: GradientPTQConfigV2 object with parameters about the tuning process.
-             weight_shape: Weight shape, used for auxiliary tensor creation.
-         """
-         super().__init__()
-         self.power_of_two = QuantizationMethod.POWER_OF_TWO == weights_quantization_cfg.weights_quantization_method
-         self.reshape_aux_shift = [-1, *[1 for _ in range(len(weight_shape))]]
-         self.num_bits = weights_quantization_cfg.weights_n_bits
-         self.weight_shape = weight_shape
-         self.max_delta_change = gptq_config.lsb_change_per_bit_width.get(self.num_bits)
-         self.quantization_parameter_learning = gptq_config.quantization_parameters_learning
-         self.m = GR_SHIFT_BASE * self.max_delta_change + GR_SHIFT_BASE
-         self.minimal_temp = gptq_config.quantizer_config.minimal_temp
-         self.maximal_temp = gptq_config.quantizer_config.maximal_temp
-         self.temperature_learning = gptq_config.quantizer_config.temperature_learning
-         self.cycle_iterations = max(1, int(gptq_config.n_epochs / gptq_config.quantizer_config.n_cycles))
-         self.shift_tensor = to_torch_tensor(init_shift_var(self.m))
-         self.tau = None
-         self.g_t = 0
-         self.p_t = None
-         self.n_iter = 0
-         self.update_gumbel_param = True
-         scale = self.cycle_iterations / (-2 * np.log(0.001))
-
-         self.gumbel_scale = gptq_config.quantizer_config.gumbel_scale
-         self.gumbel_scale_per_bitwidth = gptq_config.quantizer_config.gumbel_scale_per_bitwidth
-
-         def tau_function(i: int) -> float:
-             """
-             A function that generates the Gumbel temperature.
-             Args:
-                 i: An int that represents the current iteration number.
-
-             Returns: A temperature value.
-
-             """
-             if i < (self.cycle_iterations - 1):
-                 index = ((i + 1) % self.cycle_iterations) / scale
-             else:
-                 index = (i % self.cycle_iterations) / scale
-
-             x = np.exp(-index)
-             return self.minimal_temp + (self.maximal_temp - self.minimal_temp) * x
-
-         self.tau_function = tau_function
-
-     def get_gumbel_probability(self) -> torch.Tensor:
-         """
-         A function that returns the Gumbel probability value.
-         Returns: Gumbel probability.
-         """
-         return self.p_t
-
-     def update_iteration(self, training):
-         """
-         A function that updates parameters for GPTQ fine-tuning.
-         Args:
-             training: whether in training mode or not
-         """
-         if self.temperature_learning:
-             self.tau = ste_clip(self.temp_tensor, self.minimal_temp, self.maximal_temp)
-         else:
-             self.tau = self.tau_function(self.n_iter)
-         if self.update_gumbel_param and training:
-             self.n_iter += 1
-             self.g_t = sample_gumbel([self.m, *self.weight_shape])
-
-     @abstractmethod
-     def get_temperature_variable(self) -> Union[torch.Tensor, List]:
-         """
-         Returns temperature trainable variables.
-         """
-         raise Logger.error(f"{self.__class__.__name__} has to implement this abstract function.")
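For readers tracing the deleted `tau_function`: it decays the Gumbel-softmax temperature exponentially from `maximal_temp` toward `minimal_temp` over each annealing cycle, with the scale chosen so the exponential term has shrunk to about 0.001 by the half-cycle point. A minimal standalone sketch of that schedule (names here are illustrative, not MCT API):

```python
import numpy as np

def anneal_tau(i: int, cycle_iterations: int, minimal_temp: float, maximal_temp: float) -> float:
    """Exponentially decay the Gumbel-softmax temperature over one annealing cycle."""
    # scale is picked so that exp(-step / scale) ~= 0.001 at the half-cycle point.
    scale = cycle_iterations / (-2 * np.log(0.001))
    step = (i + 1) % cycle_iterations if i < cycle_iterations - 1 else i % cycle_iterations
    return minimal_temp + (maximal_temp - minimal_temp) * np.exp(-step / scale)
```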
model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py (deleted)
@@ -1,150 +0,0 @@
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import torch
- import torch.nn as nn
- from typing import List, Union
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
- from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.base_gumbel_weights_quantizer import BaseGumbelWeightQuantizer, init_aux_var
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
- from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import symmetric_quantizer
- from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_clip, ste_gumbel, gumbel_softmax, power_of_two_max
- from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, THRESHOLD_TENSOR, TEMP, SCALE_TENSOR
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
- from model_compression_toolkit.core.common.constants import THRESHOLD
-
-
- class SymmetricGumbelWeightQuantizer(BaseGumbelWeightQuantizer):
-     """
-     Class that implements a quantizer with trainable parameters to be used for GPTQ training.
-     """
-
-     def __init__(self,
-                  weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                  gptq_config: GradientPTQConfig,
-                  weight: torch.Tensor):
-         """
-         Construct a Pytorch module that applies a fake weight quantizer with symmetric Gumbel rounding.
-         Args:
-             weights_quantization_cfg: Configuration of weight quantization.
-             gptq_config: GradientPTQConfig object with parameters about the tuning process.
-             weight: Weight tensor, used for auxiliary tensor creation.
-         """
-         super().__init__(weights_quantization_cfg, gptq_config, weight.shape)
-         self.signed = True
-         self.min_int = -int(self.signed) * (2 ** (self.num_bits - int(self.signed)))
-         self.max_int = (2 ** (self.num_bits - int(self.signed))) - 1
-         self.threshold_tensor = to_torch_tensor(weights_quantization_cfg.weights_quantization_params.get(THRESHOLD))
-         self.scale_tensor = torch.ones(self.weight_shape)
-
-         # Set trainable tensors
-         self.set_trainable_params(weight)
-
-
-     def set_trainable_params(self, weight: torch.nn.Parameter):
-         """
-         A function to set a list of trainable parameters of the quantizer for GPTQ retraining.
-         Args:
-             weight: Weight tensor, used for auxiliary tensor creation.
-         """
-         q_error = weight - symmetric_quantizer(weight,
-                                                self.threshold_tensor,
-                                                num_bits=self.num_bits,
-                                                signed=True,
-                                                power_of_two=self.power_of_two)
-         ceil_indicator = (q_error < 0).int()  # A negative error means the chosen point is rounded up (ceil).
-         self.aux_tensor = nn.Parameter(to_torch_tensor(init_aux_var(ceil_indicator, self.weight_shape, self.m)), requires_grad=True)
-         self.trainable_params.update({AUXVAR: self.aux_tensor})
-         self.temp_tensor = nn.Parameter(to_torch_tensor(self.maximal_temp * torch.ones([1, *self.weight_shape])), requires_grad=True)
-         self.trainable_params.update({TEMP: self.temp_tensor})
-         if self.quantization_parameter_learning and not self.power_of_two:
-             self.scale_tensor = nn.Parameter(to_torch_tensor(self.scale_tensor), requires_grad=True)
-             self.trainable_params.update({SCALE_TENSOR: self.scale_tensor})
-         elif self.quantization_parameter_learning:
-             self.threshold_tensor = nn.Parameter(self.threshold_tensor, requires_grad=True)
-             self.trainable_params.update({THRESHOLD_TENSOR: self.threshold_tensor})
-         else:
-             self.trainable_params.update({THRESHOLD_TENSOR: self.threshold_tensor})
-
-     def get_aux_variable(self) -> torch.Tensor:
-         """
-         Returns auxiliary trainable variables.
-         """
-         return self.trainable_params.get(AUXVAR)
-
-     def get_quantization_variable(self) -> Union[torch.Tensor, List]:
-         """
-         Returns quantization trainable variables.
-         """
-         if self.quantization_parameter_learning and not self.power_of_two:
-             return [self.trainable_params.get(SCALE_TENSOR)]
-         else:
-             return [self.trainable_params.get(THRESHOLD_TENSOR)]
-
-     def get_temperature_variable(self) -> Union[torch.Tensor, List]:
-         """
-         Returns temperature trainable variables.
-         """
-         return self.trainable_params.get(TEMP)
-
-     def get_weight_quant_params(self) -> dict:
-         """
-         Returns a dictionary of weight quantization params.
-         """
-         threshold_tensor = self.threshold_tensor
-         if self.power_of_two:
-             threshold_tensor = power_of_two_max(threshold_tensor)
-         elif self.quantization_parameter_learning:
-             threshold_tensor = threshold_tensor * self.scale_tensor
-         return {THRESHOLD: torch_tensor_to_numpy(threshold_tensor.detach())}
-
-     def forward(self, w: nn.Parameter, training: bool = True) -> nn.Parameter:
-         """
-         Weight fake quantizer.
-         Args:
-             w: weights to quantize.
-             training: whether in training mode or not
-         Returns:
-             quantized weights
-         """
-         self.update_iteration(training)
-
-         #####################################################
-         # Gumbel Softmax
-         #####################################################
-         if training:
-             gumbel_scale = self.gumbel_scale if self.gumbel_scale_per_bitwidth is None \
-                 else self.gumbel_scale_per_bitwidth.get(self.num_bits, self.gumbel_scale)
-             self.p_t = gumbel_softmax(self.aux_tensor, self.tau, self.g_t, gumbel_scale=gumbel_scale)
-         else:
-             self.p_t = ste_gumbel(gumbel_softmax(self.aux_tensor, self.minimal_temp, 0))
-
-         auxhat_tensor = torch.sum(self.p_t * self.shift_tensor.reshape(self.reshape_aux_shift), dim=0)
-
-         #####################################################
-         # Quantizer
-         #####################################################
-         threshold_tensor = self.threshold_tensor
-         if self.power_of_two:
-             threshold_tensor = power_of_two_max(threshold_tensor)
-         delta_tensor = threshold_tensor / (2 ** (self.num_bits - int(self.signed)))
-         w0 = torch.floor(w / delta_tensor).detach()
-         w1 = w0 + auxhat_tensor
-         w2 = ste_clip(w1, min_val=self.min_int, max_val=self.max_int)
-         w_q = delta_tensor * w2
-         # Scale
-         if self.quantization_parameter_learning and not self.power_of_two:
-             w_q *= self.scale_tensor
-         return w_q
-
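Stripped of the Gumbel machinery, the quantizer at the heart of the forward pass above is a plain signed symmetric fake quantizer; the trainable part only replaces nearest rounding with a learned offset. A self-contained sketch of that base quantizer (illustrative, not the MCT API):

```python
import torch

def symmetric_fake_quant(w: torch.Tensor, threshold: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Signed symmetric fake quantization: snap to the integer grid, then dequantize."""
    delta = threshold / (2 ** (num_bits - 1))   # step size of the signed grid
    min_int = -(2 ** (num_bits - 1))            # e.g. -128 for 8 bits
    max_int = (2 ** (num_bits - 1)) - 1         # e.g. +127 for 8 bits
    q = torch.clamp(torch.round(w / delta), min_int, max_int)
    return delta * q                            # back to the float domain
```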
model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py (deleted)
@@ -1,143 +0,0 @@
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import numpy as np
- import torch
- import torch.nn as nn
- from typing import List, Union, Dict
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
- from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.base_gumbel_weights_quantizer import BaseGumbelWeightQuantizer, init_aux_var
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
- from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import uniform_quantizer, fix_range_to_include_zero
- from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_clip, ste_gumbel, gumbel_softmax
- from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_MAX_RANGE, PTQ_MIN_RANGE, TEMP
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
- from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
-
- class UniformGumbelWeightQuantizer(BaseGumbelWeightQuantizer):
-     """
-     Class that implements a quantizer with trainable parameters to be used for GPTQ training.
-     """
-
-     def __init__(self,
-                  weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                  gptq_config: GradientPTQConfig,
-                  weight: torch.nn.Parameter):
-         """
-         Construct a Pytorch module that applies a fake weight quantizer with uniform Gumbel rounding.
-         Args:
-             weights_quantization_cfg: Configuration of weight quantization.
-             gptq_config: GradientPTQConfig object with parameters about the tuning process.
-             weight: Weight tensor, used for auxiliary tensor creation.
-         """
-         super().__init__(weights_quantization_cfg, gptq_config, weight.shape)
-         self.min_int = 0
-         self.max_int = 2 ** self.num_bits - 1
-         self.max_range_tensor = weights_quantization_cfg.weights_quantization_params.get(RANGE_MAX)
-         self.min_range_tensor = weights_quantization_cfg.weights_quantization_params.get(RANGE_MIN)
-
-         # Set trainable tensors
-         self.set_trainable_params(weight)
-
-
-     def set_trainable_params(self, weight: torch.nn.Parameter):
-         """
-         A function to set a list of trainable parameters of the quantizer for GPTQ retraining.
-         """
-         self.temp_tensor = nn.Parameter(to_torch_tensor(self.maximal_temp * torch.ones([1, *self.weight_shape])), requires_grad=True)
-         self.trainable_params.update({TEMP: self.temp_tensor})
-         self.max_range_tensor = nn.Parameter(to_torch_tensor(self.max_range_tensor), requires_grad=self.quantization_parameter_learning)
-         self.trainable_params.update({PTQ_MAX_RANGE: self.max_range_tensor})
-         self.min_range_tensor = nn.Parameter(to_torch_tensor(self.min_range_tensor), requires_grad=self.quantization_parameter_learning)
-         self.trainable_params.update({PTQ_MIN_RANGE: self.min_range_tensor})
-         q_error = weight - uniform_quantizer(weight,
-                                              self.min_range_tensor,
-                                              self.max_range_tensor,
-                                              n_bits=self.num_bits)
-         ceil_indicator = (q_error < 0).int()  # A negative error means the chosen point is rounded up (ceil).
-         self.aux_tensor = nn.Parameter(to_torch_tensor(init_aux_var(ceil_indicator, self.weight_shape, self.m)), requires_grad=True)
-         self.trainable_params.update({AUXVAR: self.aux_tensor})
-
-     def get_aux_variable(self) -> torch.Tensor:
-         """
-         Returns auxiliary trainable variables.
-         """
-         return self.trainable_params.get(AUXVAR)
-
-     def get_quantization_variable(self) -> Union[torch.Tensor, List]:
-         """
-         Returns quantization trainable variables.
-         """
-         return [self.trainable_params.get(PTQ_MAX_RANGE), self.trainable_params.get(PTQ_MIN_RANGE)]
-
-     def get_temperature_variable(self) -> Union[torch.Tensor, List]:
-         """
-         Returns temperature trainable variables.
-         """
-         return self.trainable_params.get(TEMP)
-
-     def get_weight_quant_params(self) -> Dict[str, np.ndarray]:
-         """
-         Returns a dictionary of weight quantization params.
-         """
-         max_range_tensor = self.max_range_tensor
-         min_range_tensor = self.min_range_tensor
-         return {PTQ_MAX_RANGE: torch_tensor_to_numpy(max_range_tensor.detach()),
-                 PTQ_MIN_RANGE: torch_tensor_to_numpy(min_range_tensor.detach())}
-
-     def forward(self, w: nn.Parameter, training: bool = True) -> nn.Parameter:
-         """
-         Weight fake quantizer.
-         Args:
-             w: weights to quantize.
-             training: whether in training mode or not
-         Returns:
-             quantized weights
-         """
-         self.update_iteration(training)
-
-         #####################################################
-         # Gumbel Softmax
-         #####################################################
-         if training:
-             self.p_t = gumbel_softmax(self.aux_tensor, self.tau, self.g_t)
-         else:
-             self.p_t = ste_gumbel(gumbel_softmax(self.aux_tensor, self.minimal_temp, 0))
-
-         auxhat_tensor = torch.sum(self.p_t * self.shift_tensor.reshape(self.reshape_aux_shift), dim=0)
-
-         #####################################################
-         # Quantizer
-         #####################################################
-         max_range_tensor = self.max_range_tensor
-         min_range_tensor = self.min_range_tensor
-
-         # Adjust the quantization range so the quantization grid includes zero.
-         a, b = fix_range_to_include_zero(min_range_tensor, max_range_tensor, self.num_bits)
-
-         # Compute the step size of quantized values.
-         delta_tensor = (b - a) / (2 ** self.num_bits - 1)
-
-         # Apply rounding
-         w0 = torch.floor((w - a) / delta_tensor).detach()
-
-         w1 = w0 + auxhat_tensor
-
-         # Clip data in range
-         w2 = ste_clip(w1, min_val=self.min_int, max_val=self.max_int)
-
-         # Quantize the data between min/max of quantization range.
-         w_q = delta_tensor * w2 + a
-         return w_q
-
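The one subtlety in the uniform path is `fix_range_to_include_zero`: an asymmetric range [a, b] is nudged so that zero lands exactly on the quantization grid (otherwise zero weights and padding would quantize to a nonzero value). A minimal hand-written sketch of the idea; this is a simplified variant, not the deleted MCT helper, whose details may differ:

```python
import torch

def fix_range_to_include_zero(a: torch.Tensor, b: torch.Tensor, n_bits: int):
    """Shift [a, b] so that zero lands exactly on an integer grid point."""
    scale = (b - a) / (2 ** n_bits - 1)
    zero_point = torch.round(-a / scale)       # grid index closest to zero
    a_adj = -zero_point * scale                # zero now sits exactly at index zero_point
    return a_adj, a_adj + scale * (2 ** n_bits - 1)

def uniform_fake_quant(w: torch.Tensor, a: torch.Tensor, b: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Asymmetric uniform fake quantization over the adjusted range."""
    a, b = fix_range_to_include_zero(a, b, n_bits)
    delta = (b - a) / (2 ** n_bits - 1)
    q = torch.clamp(torch.round((w - a) / delta), 0, 2 ** n_bits - 1)
    return delta * q + a
```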
model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py (deleted)
@@ -1,103 +0,0 @@
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import torch
- import torch.nn as nn
- from model_compression_toolkit.core.common import BaseNode, Logger
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType
- from model_compression_toolkit.gptq.pytorch.quantizer.gptq_quantizer import BaseWeightQuantizer
- from model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.ste_weights_quantizer import STEWeightQuantizer
- from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.sym_gumbel_weights_quantizer import SymmetricGumbelWeightQuantizer
- from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.uniform_gumbel_weights_quantizer import UniformGumbelWeightQuantizer
- from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
- from model_compression_toolkit.core.pytorch.constants import KERNEL
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
-
-
- class WeightQuantizerWrapper(nn.Module):
-
-     def __init__(self, node: BaseNode, gptq_config: GradientPTQConfig, weight_quantizer: BaseWeightQuantizer):
-         """
-         Construct a Pytorch module that serves as a wrapper for a Pytorch layer, built from a given graph node.
-         Args:
-             node: Node to build its Pytorch quantizer wrapper.
-             gptq_config: GradientPTQConfig object with parameters about the tuning process.
-             weight_quantizer: BaseWeightQuantizer object for a gradient-based weight quantizer.
-         """
-         super().__init__()
-
-         # Load the operation.
-         self.op = node.type(**node.framework_attr)
-
-         # Load the weights from the graph node (weights of the trained model).
-         self.op.load_state_dict({k: torch.Tensor(v) for k, v in node.weights.items()}, strict=False)
-         self.float_weight = to_torch_tensor(getattr(self.op, KERNEL)).detach()
-
-         # Replace the nn.Parameter (which needs no gradient) with a torch.Tensor that requires gradients.
-         delattr(self.op, KERNEL)
-         setattr(self.op, KERNEL, self.float_weight)
-         setattr(getattr(self.op, KERNEL), 'requires_grad', True)
-
-         # Quantizer.
-         self.weight_quantizer = weight_quantizer(node.final_weights_quantization_cfg, gptq_config, self.float_weight)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Weight fake quantizer wrapper.
-         Args:
-             x: input to layer.
-         Returns:
-             Output of the layer after applying the operation with fake-quantized weights.
-         """
-         # Run the weight quantizer.
-         setattr(self.op, KERNEL, self.weight_quantizer(self.float_weight))
-         # Do the computation.
-         return self.op(x)
-
-
- def quantizer_wrapper(node: BaseNode, gptq_config: GradientPTQConfig) -> nn.Module:
-     """
-     Construct a Pytorch module that serves as a wrapper for a Pytorch layer, built from a given graph node.
-     Args:
-         node: Node to build its Pytorch layer.
-         gptq_config: GradientPTQConfig with parameters about the tuning process.
-     """
-     if node.is_weights_quantization_enabled():
-         quantization_method = node.final_weights_quantization_cfg.weights_quantization_method
-         if quantization_method in [QuantizationMethod.SYMMETRIC, QuantizationMethod.POWER_OF_TWO]:
-             # STE quantizer
-             # ---------------
-             if gptq_config.rounding_type == RoundingType.STE:
-                 node_instance = WeightQuantizerWrapper(node, gptq_config, STEWeightQuantizer)
-
-             # Symmetric Gumbel rounding quantizer
-             # ------------------------------------
-             elif gptq_config.rounding_type == RoundingType.GumbelRounding:
-                 node_instance = WeightQuantizerWrapper(node, gptq_config, SymmetricGumbelWeightQuantizer)
-
-         elif quantization_method == QuantizationMethod.UNIFORM:
-             # Uniform Gumbel rounding quantizer
-             # ------------------------------------
-             if gptq_config.rounding_type == RoundingType.GumbelRounding:
-                 node_instance = WeightQuantizerWrapper(node, gptq_config, UniformGumbelWeightQuantizer)
-             else:
-                 Logger.error(f"For quantization method {quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
-         else:
-             Logger.error(f"For quantization method {quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
-     else:
-         # No quantization
-         node_instance = node_builder(node)
-
-     return node_instance
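The wrapper above follows a common fake-quantization pattern: keep a frozen float copy of the kernel and swap in a freshly quantized version on every forward pass, so gradients reach the quantizer's trainable parameters. A generic sketch of that pattern with hypothetical names (not the MCT implementation, which 1.8.0 replaces with the quantizers_infrastructure quantize_wrapper modules listed above):

```python
import torch
import torch.nn as nn

class FakeQuantWrapper(nn.Module):
    """Wrap a layer so its kernel is fake-quantized on every forward pass."""

    def __init__(self, layer: nn.Module, quantizer):
        super().__init__()
        self.quantizer = quantizer                  # callable: Tensor -> Tensor
        self.float_weight = layer.weight.detach()   # frozen float copy of the kernel
        del layer.weight                            # unregister the nn.Parameter ...
        layer.weight = self.float_weight            # ... and keep a plain tensor attribute
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Swap in the fake-quantized kernel, then run the wrapped op as usual.
        self.layer.weight = self.quantizer(self.float_weight)
        return self.layer(x)

# Example usage with the symmetric fake quantizer sketched earlier:
# wrapped = FakeQuantWrapper(nn.Linear(16, 4), lambda w: symmetric_fake_quant(w, w.abs().max(), 8))
```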
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py (deleted)
@@ -1,103 +0,0 @@
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import torch
- import torch.nn as nn
- from typing import List, Union
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
- from model_compression_toolkit.gptq.pytorch.quantizer.gptq_quantizer import BaseWeightQuantizer
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
- from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_round, ste_clip
- from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
- from model_compression_toolkit.core.common.constants import THRESHOLD
-
-
- class STEWeightQuantizer(BaseWeightQuantizer):
-     """
-     Class that implements a quantizer with trainable parameters to be used for GPTQ training.
-     """
-
-     def __init__(self,
-                  weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                  gptq_config: GradientPTQConfig,
-                  weight: torch.nn.Parameter):
-         """
-         Construct a Pytorch module that applies a fake weight quantizer using STE (Straight-Through Estimator) rounding for a symmetric quantizer.
-         Args:
-             weights_quantization_cfg: Configuration of weight quantization.
-             gptq_config: GradientPTQConfig object with parameters about the tuning process.
-             weight: Weight tensor, used for auxiliary tensor creation.
-         """
-         super().__init__()
-
-         self.signed = True
-         self.num_bits = weights_quantization_cfg.weights_n_bits
-         self.min_int = -int(self.signed) * (2 ** (self.num_bits - int(self.signed)))
-         self.max_int = (2 ** (self.num_bits - int(self.signed))) - 1
-         self.weight_shape = weight.shape
-         self.threshold_values = weights_quantization_cfg.weights_quantization_params.get(THRESHOLD)
-         self.delta_tensor = self.threshold_values / (2 ** (self.num_bits - int(self.signed)))
-         self.max_delta_change = gptq_config.lsb_change_per_bit_width.get(self.num_bits)
-
-         # Set trainable tensors
-         self.set_trainable_params()
-
-         # Create tensors
-         self.delta_tensor = to_torch_tensor(self.delta_tensor)
-         self.max_tensor_change = self.delta_tensor * self.max_delta_change
-
-     def set_trainable_params(self):
-         """
-         A function to set a list of trainable parameters of the quantizer for GPTQ retraining.
-         """
-         self.aux_tensor = nn.Parameter(to_torch_tensor(torch.zeros(self.weight_shape)), requires_grad=True)
-         self.trainable_params.update({AUXVAR: self.aux_tensor})
-
-     def get_aux_variable(self) -> torch.Tensor:
-         """
-         Returns auxiliary trainable variables.
-         """
-         return self.trainable_params.get(AUXVAR)
-
-     def get_quantization_variable(self) -> Union[torch.Tensor, List]:
-         """
-         Returns quantization trainable variables.
-         """
-         return []
-
-     def get_weight_quantization_params(self) -> dict:
-         """
-         Returns a dictionary of weight quantization params.
-         """
-         return {THRESHOLD: self.threshold_values}
-
-     def forward(self, w: nn.Parameter, training: bool = True) -> nn.Parameter:
-         """
-         Weight fake quantizer.
-         Args:
-             w: weights to quantize.
-             training: whether in training mode or not
-         Returns:
-             quantized weights
-         """
-         v0 = ste_clip(self.aux_tensor, min_val=-self.max_tensor_change, max_val=self.max_tensor_change)
-         v1 = v0 / self.delta_tensor
-         w0 = torch.round(w / self.delta_tensor).detach()
-         w1 = w0 + v1
-         w2 = ste_round(w1)
-         w3 = ste_clip(w2, min_val=self.min_int, max_val=self.max_int)
-         w_q = self.delta_tensor * w3
-         return w_q
-
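The `ste_round`/`ste_clip` helpers this quantizer imports are straight-through estimators: the forward pass applies the hard operation, while the backward pass behaves as the identity so gradients still reach the trainable offset. A common formulation, shown as a sketch (the deleted quant_utils implementation may differ in detail):

```python
import torch

def ste_round(x: torch.Tensor) -> torch.Tensor:
    # round() has zero gradient almost everywhere; re-attach the identity gradient.
    return x + (torch.round(x) - x).detach()

def ste_clip(x: torch.Tensor, min_val, max_val) -> torch.Tensor:
    # Clip in the forward pass; let gradients pass through unchanged.
    return x + (torch.clamp(x, min_val, max_val) - x).detach()
```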