mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -1,247 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import tensorflow as tf
16
- import numpy as np
17
-
18
- from model_compression_toolkit import GumbelConfig
19
- from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
20
- from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.base_gumbel_rounding import GumbelRoundingBase, \
21
- init_aux_var
22
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
23
- from tensorflow.python.framework.tensor_shape import TensorShape
24
- from model_compression_toolkit.core.common.defaultdict import DefaultDict
25
- from typing import Dict, Any, List
26
- from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.gumbel_softmax import gumbel_softmax, ste_gumbel
27
- from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX
28
- from model_compression_toolkit.gptq.common import gptq_constants
29
- from model_compression_toolkit.gptq.keras.quantizer.ste_rounding.uniform_ste import rounding_uniform_quantizer
30
-
31
-
32
- def gumbel_rounding_uniform_quantizer(tensor_data: tf.Tensor,
33
- auxvar_tensor: tf.Variable,
34
- range_min: tf.Tensor,
35
- range_max: tf.Tensor,
36
- n_bits: int) -> tf.Tensor:
37
- """
38
- Quantize a tensor according to given range (min, max) and number of bits.
39
-
40
- Args:
41
- tensor_data: Tensor values to quantize.
42
- auxvar_tensor: Tensor that manifests the bit shift of the weight due to gptq.
43
- range_min: minimum bound of the range for quantization (or array of min values per channel).
44
- range_max: maximum bound of the range for quantization (or array of max values per channel).
45
- n_bits: Number of bits to quantize the tensor.
46
-
47
- Returns:
48
- Quantized data.
49
- """
50
-
51
- # adjusts the quantization range so the quantization grid includes zero.
52
- a, b = qutils.fix_range_to_include_zero(range_min, range_max, n_bits)
53
-
54
- # Compute the step size of quantized values.
55
- delta = (b - a) / (2 ** n_bits - 1)
56
-
57
- input_tensor_int = tf.stop_gradient(tf.floor((tensor_data - a) / delta)) # Apply rounding
58
- tensor_q = input_tensor_int + auxvar_tensor
59
-
60
- # Clip data in range
61
- clipped_tensor = qutils.ste_clip(tensor_q, min_val=0, max_val=2 ** n_bits - 1)
62
-
63
- # Quantize the data between min/max of quantization range.
64
- q = delta * clipped_tensor + a
65
- return q
66
-
67
-
68
- class UniformGumbelRounding(GumbelRoundingBase):
69
- """
70
- Trainable constrained quantizer to quantize a layer inputs.
71
- """
72
- PTQ_MIN_RANGE = "_min_range"
73
- PTQ_MAX_RANGE = "_max_range"
74
-
75
- def __init__(self, num_bits: int, per_axis: bool, signed: bool, quantization_parameter_learning: bool,
76
- min_range: np.ndarray, max_range: np.ndarray, gumbel_config: GumbelConfig,
77
- quantization_axis: int = -1, max_lsbs_change_map: dict = DefaultDict({}, lambda: 1),
78
- max_iteration: int = 10000):
79
- """
80
- Initialize a TrainableWeightQuantizer object with parameters to use
81
- for the quantization.
82
-
83
- Args:
84
- num_bits: Number of bits to use for the quantization.
85
- per_axis: Whether to quantize per-channel or per-tensor.
86
- signed: Signedness to use for the quantization range.
87
- quantization_parameter_learning: Whether to learn the quantization range parameters during training.
88
- min_range: a numpy array of the min range.
89
- max_range: a numpy array of the max range.
90
- gumbel_config: A class with the gumbel rounding configurations.
91
- quantization_axis: Axis of tensor to use for the quantization.
92
- max_lsbs_change_map: a mapping between number of bits to max lsb change.
93
- max_iteration: The number of iterations of gptq.
94
- """
95
- super().__init__(num_bits, per_axis, signed, False, False, quantization_parameter_learning,
96
- quantization_axis, gumbel_config,
97
- max_lsbs_change_map,
98
- max_iteration)
99
- self.threshold_shape = np.asarray(min_range).shape
100
- self.min_range = np.reshape(np.asarray(min_range), [-1]) if self.per_axis else float(
101
- min_range)
102
- self.max_range = np.reshape(np.asarray(max_range), [-1]) if self.per_axis else float(
103
- max_range)
104
- self.k_threshold = len(self.max_range) if self.per_axis else 1
105
-
106
- def build(self,
107
- tensor_shape: TensorShape,
108
- name: str,
109
- layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
110
- """
111
- Add min and max variables to layer.
112
- Args:
113
- tensor_shape: Tensor shape the quantizer quantize.
114
- name: Prefix of variables names.
115
- layer: Layer to add the variables to. The variables are saved
116
- in the layer's scope.
117
-
118
- Returns:
119
- Dictionary of new variables.
120
- """
121
- super().build(tensor_shape, name, layer)
122
-
123
- if self.per_axis:
124
- input_shape = tensor_shape
125
- n_axis = len(input_shape)
126
- quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
127
- self.quantization_axis
128
- reshape_shape = [self.k_threshold if i == quantization_axis else 1 for i in range(n_axis)]
129
- else:
130
- reshape_shape = [self.k_threshold]
131
-
132
- max_range = layer.add_weight(
133
- name + self.PTQ_MAX_RANGE,
134
- shape=reshape_shape,
135
- initializer=tf.keras.initializers.Constant(1.0),
136
- trainable=self.quantization_parameter_learning)
137
- max_range.assign(self.max_range.reshape(reshape_shape))
138
-
139
- min_range = layer.add_weight(
140
- name + self.PTQ_MIN_RANGE,
141
- shape=reshape_shape,
142
- initializer=tf.keras.initializers.Constant(1.0),
143
- trainable=self.quantization_parameter_learning)
144
- min_range.assign(self.min_range.reshape(reshape_shape))
145
-
146
- auxvar_tensor = layer.add_weight(
147
- name + gptq_constants.AUXVAR,
148
- shape=[self.m, *self.w_shape],
149
- initializer=tf.keras.initializers.Constant(0.0),
150
- trainable=True)
151
- w = getattr(layer.layer, name)
152
-
153
- q_error = w - rounding_uniform_quantizer(w, min_range, max_range,
154
- n_bits=self.num_bits)
155
- ceil_indicator = (q_error < 0).numpy().astype("int") # Negative error means the choose point is rounded to ceil.
156
- auxvar_tensor.assign(init_aux_var(ceil_indicator, self.w_shape, self.m))
157
-
158
- self.quantizer_parameters.update({gptq_constants.AUXVAR: auxvar_tensor,
159
- self.PTQ_MAX_RANGE: max_range,
160
- self.PTQ_MIN_RANGE: min_range})
161
- return self.quantizer_parameters
162
-
163
- def __call__(self, inputs: tf.Tensor,
164
- training: bool,
165
- weights: Dict[str, tf.Variable],
166
- **kwargs: Dict[str, Any]):
167
- """
168
- Quantize a tensor.
169
- Args:
170
- inputs: Input tensor to quantize.
171
- training: Whether the graph is in training mode.
172
- weights: Dictionary of weights the quantizer can use to quantize the tensor.
173
- **kwargs: Additional variables the quantizer may receive.
174
-
175
- Returns:
176
- The quantized tensor.
177
- """
178
-
179
- auxvar = weights[gptq_constants.AUXVAR]
180
- ar_iter = weights[gptq_constants.GPTQ_ITER]
181
- ptq_min_range = weights[self.PTQ_MIN_RANGE]
182
- ptq_max_range = weights[self.PTQ_MAX_RANGE]
183
- aux_index_shift = weights[gptq_constants.AUXSHIFT]
184
- self.update_iteration(training, ar_iter)
185
- if self.per_axis:
186
- input_shape = inputs.shape
187
- n_axis = len(input_shape)
188
- quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
189
- self.quantization_axis
190
- reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]
191
-
192
- reshape_shape_aux_ind = [-1, *[1 for _ in range(n_axis)]]
193
- #####################################################
194
- # Gumbel Softmax
195
- #####################################################
196
- if training:
197
- p_t = gumbel_softmax(auxvar, self.tau, self.g_t)
198
- else:
199
- p_t = gumbel_softmax(auxvar, self.minimal_temp, 0)
200
- p_t = ste_gumbel(p_t)
201
- self.p_t = p_t
202
- #####################################################
203
- # Calculate v hat and threshold hat
204
- #####################################################
205
- ptq_min_range = tf.reshape(ptq_min_range, reshape_shape)
206
- ptq_max_range = tf.reshape(ptq_max_range, reshape_shape)
207
-
208
- auxvar_hat = tf.reduce_sum(p_t * tf.reshape(aux_index_shift, reshape_shape_aux_ind), axis=0)
209
- #####################################################
210
- # Quantized Input
211
- #####################################################
212
- q_tensor = gumbel_rounding_uniform_quantizer(inputs, auxvar_hat,
213
- ptq_min_range,
214
- ptq_max_range,
215
- self.num_bits)
216
- return q_tensor
217
- else:
218
- raise NotImplemented
219
- return gumbel_rounding_uniform_quantizer(inputs, auxvar_hat,
220
- ptq_max_range,
221
- ptq_min_range,
222
- self.num_bits)
223
-
224
- def get_quant_config(self, layer) -> Dict[str, np.ndarray]:
225
- """
226
- Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
227
-
228
- Args:
229
- layer: quantized layer
230
-
231
- Returns:
232
- A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
233
- Keys must match NodeQuantizationConfig attributes
234
-
235
- """
236
- min_range = self.quantizer_parameters[self.PTQ_MIN_RANGE]
237
- max_range = self.quantizer_parameters[self.PTQ_MAX_RANGE]
238
- return {RANGE_MIN: min_range.numpy().reshape(self.threshold_shape),
239
- RANGE_MAX: max_range.numpy().reshape(self.threshold_shape)}
240
-
241
- def get_quantization_variable(self) -> List[tf.Tensor]:
242
- """
243
- This function return a list of quantizer parameters.
244
- Returns: A list of the quantizer parameters
245
-
246
- """
247
- return [self.quantizer_parameters[self.PTQ_MIN_RANGE], self.quantizer_parameters[self.PTQ_MAX_RANGE]]
@@ -1,50 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import tensorflow as tf
16
- from model_compression_toolkit.core.keras.constants import KERNEL
17
-
18
-
19
- def get_kernel(weights_list: list) -> tf.Tensor:
20
- """
21
- This function a list of weights and return the kernel
22
- Args:
23
- weights_list: A list of Tensors
24
-
25
- Returns: The kernel tensor.
26
-
27
- """
28
- for w in weights_list:
29
- if KERNEL in w.name:
30
- return w
31
- raise Exception("Can't find kernel variable")
32
-
33
-
34
- def threshold_reshape(threshold_tensor: tf.Tensor, input_w: tf.Tensor, in_quantization_axis: int) -> tf.Tensor:
35
- """
36
- This function take a threshold tensor and re-aline it to the weight tensor channel axis.
37
- Args:
38
- threshold_tensor: A tensor of threshold
39
- input_w: A weight tensor
40
- in_quantization_axis: A int value that represent the channel axis.
41
-
42
- Returns: A reshape tensor of threshold.
43
-
44
- """
45
- input_shape = input_w.shape
46
- n_axis = len(input_shape)
47
- quantization_axis = n_axis + in_quantization_axis if in_quantization_axis < 0 else in_quantization_axis
48
- reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]
49
- ptq_threshold_tensor = tf.reshape(threshold_tensor, reshape_shape)
50
- return ptq_threshold_tensor
@@ -1,49 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import tensorflow as tf
16
-
17
- from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
18
-
19
-
20
- def rounding_uniform_quantizer(tensor_data: tf.Tensor,
21
- range_min: tf.Tensor,
22
- range_max: tf.Tensor,
23
- n_bits: int) -> tf.Tensor:
24
- """
25
- Quantize a tensor according to given range (min, max) and number of bits.
26
-
27
- Args:
28
- tensor_data: Tensor values to quantize.
29
- range_min: minimum bound of the range for quantization (or array of min values per channel).
30
- range_max: maximum bound of the range for quantization (or array of max values per channel).
31
- n_bits: Number of bits to quantize the tensor.
32
-
33
- Returns:
34
- Quantized data.
35
- """
36
- # adjusts the quantization rage so the quantization grid include zero.
37
- a, b = qutils.fix_range_to_include_zero(range_min, range_max, n_bits)
38
-
39
- # Compute the step size of quantized values.
40
- delta = (b - a) / (2 ** n_bits - 1)
41
-
42
- input_tensor_int = qutils.ste_round((tensor_data - a) / delta) # Apply rounding
43
-
44
- # Clip data in range
45
- clipped_tensor = qutils.ste_clip(input_tensor_int, min_val=0, max_val=2 ** n_bits - 1)
46
-
47
- # Quantize the data between min/max of quantization range.
48
- q = delta * clipped_tensor + a
49
- return q
@@ -1,94 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import torch
16
- import torch.nn as nn
17
- from typing import List
18
- from model_compression_toolkit.gptq.pytorch.quantizer.quantizer_wrapper import WeightQuantizerWrapper
19
- from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.base_gumbel_weights_quantizer import BaseGumbelWeightQuantizer
20
- from model_compression_toolkit.core.pytorch.constants import BIAS
21
-
22
-
23
- def get_trainable_parameters(fxp_model: nn.Module,
24
- add_bias: bool = False,
25
- quantization_parameters_learning: bool = False,
26
- is_gumbel: bool = False) -> (List[nn.Parameter], List[nn.Parameter], List[nn.Parameter]):
27
- """
28
- Get trainable parameters from all layers in a model
29
-
30
- Args:
31
- fxp_model: Model to get its trainable parameters.
32
- add_bias: Whether to include biases of the model (if there are) or not.
33
- quantization_parameters_learning: Whether to include quantization parameters of the model or not.
34
- is_gumbel: Whether the fxp model is quantized using Gumbel Rounding
35
- Returns:
36
- A list of trainable variables in a model. Each item is a list of a layers weights.
37
- """
38
-
39
- trainable_aux_weights = nn.ParameterList()
40
- trainable_threshold = nn.ParameterList()
41
- trainable_bias = nn.ParameterList()
42
- trainable_temperature = nn.ParameterList()
43
-
44
- for layer in fxp_model.modules():
45
- if isinstance(layer, WeightQuantizerWrapper):
46
- trainable_aux_weights.append(layer.weight_quantizer.get_aux_variable())
47
- if quantization_parameters_learning:
48
- trainable_threshold.extend(layer.weight_quantizer.get_quantization_variable())
49
- if is_gumbel:
50
- trainable_temperature.append(layer.weight_quantizer.get_temperature_variable())
51
- if add_bias and hasattr(layer.op, BIAS):
52
- bias = getattr(layer.op, BIAS)
53
- trainable_bias.append(bias)
54
-
55
- return trainable_aux_weights, trainable_bias, trainable_threshold, trainable_temperature
56
-
57
-
58
- def get_gumbel_probability(fxp_model: nn.Module) -> List[torch.Tensor]:
59
- """
60
- This function return the gumbel softmax probability of GumRounding
61
- Args:
62
- fxp_model: A model to be quantized with GumRounding
63
-
64
- Returns: A list of tensors.
65
-
66
- """
67
- gumbel_prob_aux = []
68
- for layer in fxp_model.modules():
69
- if isinstance(layer, WeightQuantizerWrapper) and isinstance(layer.weight_quantizer, BaseGumbelWeightQuantizer):
70
- gumbel_prob_aux.append(layer.weight_quantizer.get_gumbel_probability())
71
- return gumbel_prob_aux
72
-
73
-
74
- def get_weights_for_loss(fxp_model: nn.Module) -> [List, List]:
75
- """
76
- Get all float and quantized kernels for the GPTQ loss
77
-
78
- Args:
79
- fxp_model: Model to get its float and quantized weights.
80
-
81
- Returns:
82
- A list of float kernels, each item is the float kernel of the layer
83
- A list of quantized kernels, each item is the quantized kernel of the layer
84
- """
85
-
86
- flp_weights_list, fxp_weights_list = [], []
87
- for layer in fxp_model.modules():
88
- if isinstance(layer, WeightQuantizerWrapper):
89
- # Collect pairs of float and quantized weights per layer
90
- weights = layer.op.weight
91
- flp_weights_list.append(weights)
92
- fxp_weights_list.append(layer.weight_quantizer(weights, training=False))
93
-
94
- return flp_weights_list, fxp_weights_list
@@ -1,113 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import torch
16
- from typing import Tuple, List
17
- from model_compression_toolkit.core.common.user_info import UserInformation
18
- from model_compression_toolkit.core import common
19
- from model_compression_toolkit.core.common.graph.base_graph import BaseNode
20
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
21
- from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
22
- from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, PytorchModel
23
- from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
24
- from model_compression_toolkit.gptq.pytorch.quantizer.quantizer_wrapper import quantizer_wrapper
25
- from model_compression_toolkit.core.pytorch.utils import get_working_device
26
- from model_compression_toolkit.core.pytorch.constants import BUFFER
27
- from model_compression_toolkit.core.pytorch.reader.node_holders import BufferHolder
28
- from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
29
-
30
-
31
- class GPTQPytorchModel(PytorchModel):
32
- """
33
- Class for GPTQ PyTorch model.
34
- """
35
-
36
- def __init__(self,
37
- graph: common.Graph,
38
- gptq_config: GradientPTQConfig,
39
- append2output=None,
40
- return_float_outputs: bool = True):
41
- """
42
- Args:
43
- graph: Graph to build the model from.
44
- gptq_config: Configuration for GPTQ optimization.
45
- append2output: Nodes to append to model's output.
46
- return_float_outputs: Whether the model returns float tensors or not.
47
- """
48
-
49
- super().__init__(graph,
50
- append2output,
51
- DEFAULT_PYTORCH_INFO,
52
- return_float_outputs)
53
-
54
- for node in graph.nodes():
55
- if not isinstance(node, FunctionalNode):
56
- if node.type == BufferHolder:
57
- self.add_module(node.name, node_builder(node))
58
- self.get_submodule(node.name).register_buffer(node.name,torch.Tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device()))
59
- else:
60
- self.add_module(node.name, quantizer_wrapper(node, gptq_config))
61
-
62
-
63
- def _quantize_node_activations(self,
64
- node: BaseNode,
65
- input_tensors: List[torch.Tensor]) -> List[torch.Tensor]:
66
- """
67
- Quantize node's activation given input tensors.
68
-
69
- Args:
70
- node: Node to quantize its outputs.
71
- input_tensors: Input tensors of the node.
72
-
73
- Returns:
74
- Output of the node.
75
-
76
- """
77
- return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
78
-
79
-
80
- class GPTQPytorchModelBuilder(PyTorchModelBuilder):
81
- """
82
- Builder of GPTQ Pytorch models.
83
- """
84
-
85
- def __init__(self,
86
- graph: common.Graph,
87
- gptq_config: GradientPTQConfig,
88
- append2output=None,
89
- return_float_outputs: bool = True):
90
- """
91
-
92
- Args:
93
- graph: Graph to build the model from.
94
- gptq_config: Configuration for GPTQ optimization.
95
- append2output: Nodes to append to model's output.
96
- return_float_outputs: Whether the model returns float tensors or not.
97
- """
98
- super().__init__(graph,
99
- append2output,
100
- DEFAULT_PYTORCH_INFO,
101
- return_float_outputs)
102
- self.gptq_config = gptq_config
103
-
104
- def build_model(self) -> Tuple[PytorchModel, UserInformation]:
105
- """
106
- Build a GPTQ PyTorch model and return it.
107
- Returns:
108
- GPTQ PyTorch model and user information.
109
- """
110
- return GPTQPytorchModel(self.graph,
111
- self.gptq_config,
112
- self.append2output,
113
- self.return_float_outputs), self.graph.user_info
@@ -1,71 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import torch
16
- import torch.nn as nn
17
- from typing import List, Union, Dict, Any
18
- from abc import abstractmethod
19
- from model_compression_toolkit.core.common import Logger
20
-
21
-
22
- class BaseWeightQuantizer(nn.Module):
23
-
24
- def __init__(self):
25
- """
26
- Construct a Base Pytorch model that utilizes a fake weight quantizer
27
- """
28
- super().__init__()
29
- self.trainable_params = dict()
30
-
31
- def get_trainable_params(self) -> List:
32
- """
33
- A function to get a list of trainable parameters of the quantizer for GPTQ retraining
34
- Returns:
35
- A list of trainable tensors
36
- """
37
- return [value for value in self.trainable_params.values() if value is not None]
38
-
39
- @abstractmethod
40
- def get_aux_variable(self) -> torch.Tensor:
41
- """
42
- Returns auxiliary trainable variables
43
- """
44
- raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
45
-
46
- @abstractmethod
47
- def get_quantization_variable(self) -> Union[torch.Tensor, List]:
48
- """
49
- Returns quantization trainable variables
50
- """
51
- raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
52
-
53
-
54
- @abstractmethod
55
- def get_weight_quantization_params(self) -> Dict[str, Any]:
56
- """
57
- Returns weight quantization dictionary params
58
- """
59
- raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
60
-
61
- @abstractmethod
62
- def forward(self, w:nn.parameter, training:bool = True) -> torch.Tensor:
63
- """
64
- Forward-Pass
65
- Args:
66
- w: weights to quantize.
67
- training: whether in training mode or not
68
- Returns:
69
- quantized weights
70
- """
71
- raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
@@ -1,14 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================