mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -14,35 +14,69 @@
14
14
  # ==============================================================================
15
15
  from typing import Any
16
16
 
17
- from keras.engine.input_layer import InputLayer
18
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapperV2
19
17
 
20
18
  from model_compression_toolkit.core.common import Logger
21
- from model_compression_toolkit.exporter.model_wrapper.keras.builder.quantize_config_to_node import \
22
- SUPPORTED_QUANTIZATION_CONFIG
23
- from model_compression_toolkit.exporter.model_wrapper.keras.extended_quantize_wrapper import ExtendedQuantizeWrapper
19
+ from model_compression_toolkit.core.common.constants import FOUND_TF
24
20
 
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
25
22
 
26
- def is_keras_layer_exportable(layer: Any) -> bool:
27
- """
28
- Check whether a Keras layer is a valid exportable layer or not.
29
23
 
30
- Args:
31
- layer: Keras layer to check if considered to be valid for exporting.
24
+ if FOUND_TF:
25
+ from keras.engine.input_layer import InputLayer
26
+ from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
32
27
 
33
- Returns:
28
+ def is_keras_layer_exportable(layer: Any) -> bool:
29
+ """
34
30
  Check whether a Keras layer is a valid exportable layer or not.
35
- """
36
- # Keras Input layers are not wrapped
37
- if isinstance(layer, InputLayer):
38
- return True
39
31
 
40
- valid_layer = isinstance(layer, ExtendedQuantizeWrapper)
41
- if not valid_layer:
42
- Logger.error(f'Exportable layer must be wrapped using ExtendedQuantizeWrapper, but layer {layer.name} is of type {type(layer)}')
32
+ Args:
33
+ layer: Keras layer to check if considered to be valid for exporting.
34
+
35
+ Returns:
36
+ Check whether a Keras layer is a valid exportable layer or not.
37
+ """
38
+ # Keras Input layers are not wrapped
39
+ if isinstance(layer, InputLayer):
40
+ return True
41
+
42
+ valid_layer = isinstance(layer, KerasQuantizationWrapper)
43
+ if not valid_layer:
44
+ Logger.error(
45
+ f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
46
+ f'{type(layer)}') # pragma: no cover
47
+
48
+ valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
49
+ if not valid_weights_quantizers:
50
+ Logger.error(
51
+ f'KerasQuantizationWrapper must have a weights_quantizers but has a '
52
+ f'{type(layer.weights_quantizers)} object') # pragma: no cover
43
53
 
44
- valid_quantize_config = type(layer.quantize_config) in SUPPORTED_QUANTIZATION_CONFIG
45
- if not valid_quantize_config:
46
- Logger.error(f'QuantizeConfig of layer is not supported. Type: {type(layer.quantize_config)}. Supported configs: {SUPPORTED_QUANTIZATION_CONFIG}.')
54
+ for _, weights_quantizer in layer.weights_quantizers.items():
55
+ if not isinstance(weights_quantizer, BaseInferableQuantizer):
56
+ Logger.error(
57
+ f'weights_quantizer must be a BaseInferableQuantizer object but has a '
58
+ f'{type(weights_quantizer)} object') # pragma: no cover
47
59
 
48
- return True
60
+ valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
61
+ if not valid_activation_quantizers:
62
+ Logger.error(
63
+ f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
64
+ f'{type(layer.activation_quantizers)} object') # pragma: no cover
65
+
66
+ for activation_quantizers in layer.activation_quantizers:
67
+ if not isinstance(activation_quantizers, BaseInferableQuantizer):
68
+ Logger.error(
69
+ f'activation_quantizers must be a BaseInferableQuantizer object but has a '
70
+ f'{type(activation_quantizers)} object') # pragma: no cover
71
+
72
+ quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
73
+ is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
74
+ if not is_valid_quantizers:
75
+ Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer') # pragma: no cover
76
+
77
+ return True
78
+ else:
79
+ def is_keras_layer_exportable(*args, **kwargs): # pragma: no cover
80
+ Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
81
+ 'when using is_keras_layer_exportable. '
82
+ 'Could not find Tensorflow package.')
@@ -13,132 +13,49 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import List, Any, Tuple
17
-
18
- import torch
19
16
 
17
+ from model_compression_toolkit import quantizers_infrastructure as qi
20
18
  from model_compression_toolkit.core import common
21
- from model_compression_toolkit.core.common import BaseNode, Graph
22
- from model_compression_toolkit.core.common.user_info import UserInformation
23
- from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
24
- from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
25
- PytorchModel
26
- from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.quantized_layer_wrapper import \
27
- QuantizedLayerWrapper
28
- from model_compression_toolkit.core.pytorch.constants import BUFFER, CONSTANT
29
- from model_compression_toolkit.core.pytorch.reader.node_holders import BufferHolder, ConstantHolder
30
- from model_compression_toolkit.core.pytorch.utils import get_working_device
31
-
32
-
33
- from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantize_config import get_quantization_config
34
-
35
-
36
- def get_fully_quantized_pytorch_model(graph: Graph):
37
- """
38
- Convert graph to fully quantized PyTorch model.
19
+ from model_compression_toolkit.core.common import Graph, Logger
20
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
39
21
 
40
- Args:
41
- graph: Graph to convert to a PyTorch model.
22
+ if FOUND_TORCH:
23
+ import torch
24
+ from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
25
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
26
+ get_quantization_quantizers
42
27
 
43
- Returns:
44
- Fully quantized PyTorch model.
45
- """
46
- return FullyQuantizedPyTorchModelBuilder(graph=graph).build_model()
47
-
48
-
49
-
50
- class FullyQuantizedPyTorchModel(PytorchModel):
51
- """
52
- PyTorch model with all quantization information.
53
- """
54
-
55
- def __init__(self,
56
- graph: common.Graph):
28
+ def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
57
29
  """
30
+ A function which takes a computational graph node and a pytorch module and
31
+ perform the quantization wrapping
58
32
 
59
33
  Args:
60
- graph: Graph to build its corresponding Pytorch model.
61
- """
34
+ node: A node of mct graph.
35
+ module: A Pytorch module
62
36
 
63
- super().__init__(graph)
37
+ Returns: Wrapped layer
64
38
 
65
-
66
- def _add_modules(self):
67
39
  """
40
+ weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
41
+ wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
42
+ return wrapped_layer
68
43
 
69
- Add nodes in graph as modules.
70
-
71
- """
72
- for n in self.node_sort:
73
- if n.type == BufferHolder:
74
- self.add_module(n.name, node_builder(n))
75
- self.get_submodule(n.name).register_buffer(n.name,
76
- torch.Tensor(n.get_weights_by_keys(BUFFER)).to(get_working_device()))
77
- elif n.type == ConstantHolder:
78
- self.add_module(n.name, node_builder(n))
79
- self.get_submodule(n.name).register_buffer(n.name,
80
- torch.Tensor(n.get_weights_by_keys(CONSTANT)).to(get_working_device()))
81
-
82
- else:
83
- # Create a wrapper based on the corresponding quantization config.
84
- layer_wrapper = QuantizedLayerWrapper(n, get_quantization_config(n))
85
- # Add the wrapped layer to the model.
86
- self.add_module(n.name, layer_wrapper)
87
-
88
- def _get_op_func(self,
89
- node: BaseNode,
90
- configurable_nodes_names: List[str]) -> Any:
91
- """
92
- Get the operator corresponding to the passed node.
93
-
94
- Args:
95
- node: Node to get its op.
96
- configurable_nodes_names: List of nodes that are configurable.
97
-
98
- Returns:
99
- Operator (module) of the node.
100
- """
101
- return getattr(self, node.name)
102
44
 
103
- def _quantize_node_activations(self,
104
- node: BaseNode,
105
- input_tensors: List[torch.Tensor]) -> List[torch.Tensor]:
45
+ def get_exportable_pytorch_model(graph: Graph):
106
46
  """
107
- Quantize node's activation given input tensors.
47
+ Convert graph to fully quantized PyTorch model.
108
48
 
109
49
  Args:
110
- node: Node to quantize its outputs.
111
- input_tensors: Input tensors of the node.
50
+ graph: Graph to convert to a PyTorch model.
112
51
 
113
52
  Returns:
114
- Output of the node.
115
-
116
- """
117
- return input_tensors
118
-
119
-
120
-
121
-
122
-
123
- class FullyQuantizedPyTorchModelBuilder(PyTorchModelBuilder):
124
- """
125
- Fully-Quantized PyTorch model.
126
- """
127
-
128
- def __init__(self,
129
- graph: common.Graph):
130
- """
131
-
132
- Args:
133
- graph: Graph to build the model from.
134
- """
135
-
136
- super().__init__(graph)
137
-
138
- def build_model(self) -> Tuple[PytorchModel, UserInformation]:
139
- """
140
- Build a PyTorch fully quantized model and return it.
141
- Returns: Fully quantized PyTorch model and user information.
142
-
143
- """
144
- return FullyQuantizedPyTorchModel(self.graph), self.graph.user_info
53
+ Fully quantized PyTorch model.
54
+ """
55
+ return PyTorchModelBuilder(graph=graph,
56
+ wrapper=fully_quantized_wrapper).build_model()
57
+ else:
58
+ def get_exportable_pytorch_model(*args, **kwargs): # pragma: no cover
59
+ Logger.error('Installing torch is mandatory '
60
+ 'when using get_exportable_pytorch_model. '
61
+ 'Could not find PyTorch package.')
@@ -12,31 +12,83 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import numpy as np
16
15
 
17
- from typing import List, Callable
16
+ from typing import Dict, Any
18
17
 
19
18
  from model_compression_toolkit.core.common import BaseNode, Logger
20
- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
21
- from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import calculate_delta
19
+ from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
20
+ SCALE_PER_CHANNEL, CLUSTER_CENTERS
22
21
  from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
23
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
24
+ get_inferable_quantizer_class
25
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
26
+ constants as qi_inferable_quantizers_constants, BasePyTorchInferableQuantizer
27
+ import numpy as np
23
28
 
24
- # Supporting other quantizer types in the future
25
- from model_compression_toolkit.exporter.model_wrapper.pytorch.quantizers.fq_quantizer import FakeQuantQuantizer
26
- from model_compression_toolkit.exporter.model_wrapper.pytorch.quantizers.uniform_weights_quantizer import \
27
- UniformWeightsQuantizer
28
- import torch
29
29
 
30
- SUPPORTED_WEIGHT_QUANTIZER_TYPES = [QuantizationMethod.POWER_OF_TWO,
31
- QuantizationMethod.SYMMETRIC,
32
- QuantizationMethod.UNIFORM]
30
+ def get_weights_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
31
+ # Get the weights quantization configuration for the node
32
+ node_w_qc = node.final_weights_quantization_cfg
33
+ quantization_method = node_w_qc.weights_quantization_method
34
+
35
+ # Return the appropriate quantization parameters based on the quantization method
36
+ if quantization_method in [QuantizationMethod.POWER_OF_TWO,
37
+ QuantizationMethod.SYMMETRIC]:
38
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
39
+ qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[THRESHOLD].flatten(),
40
+ qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
41
+ qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
42
+
43
+ elif quantization_method in [QuantizationMethod.UNIFORM]:
44
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
45
+ qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
46
+ qi_inferable_quantizers_constants.MIN_RANGE: node_w_qc.weights_quantization_params[RANGE_MIN].flatten(),
47
+ qi_inferable_quantizers_constants.MAX_RANGE: node_w_qc.weights_quantization_params[RANGE_MAX].flatten(),
48
+ qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
49
+
50
+ elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
51
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
52
+ qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
53
+ qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
54
+ qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
55
+ qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
56
+ # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
33
57
 
34
- SUPPORTED_ACTIVATION_QUANTIZER_TYPES = [QuantizationMethod.POWER_OF_TWO,
35
- QuantizationMethod.SYMMETRIC,
36
- QuantizationMethod.UNIFORM]
58
+ else:
59
+ Logger.critical(f'Not supported quantization method for weights inferable quantizers.') # pragma: no cover
60
+
61
+
62
def get_activation_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
    """
    Build the keyword arguments used to construct a node's inferable activation quantizer.

    Args:
        node: Node whose final activation quantization configuration is read.

    Returns:
        A dictionary keyed by the inferable-quantizer constants. Its content depends
        on the activation quantization method of the node. When the method matches
        none of the handled cases, Logger.critical is called and no dictionary is
        returned explicitly.
    """
    act_cfg = node.final_activation_quantization_cfg
    method = act_cfg.activation_quantization_method
    keys = qi_inferable_quantizers_constants

    def _as_array(param_name):
        # The stored parameter is wrapped in a one-element list before being
        # converted to a numpy array.
        return np.asarray([act_cfg.activation_quantization_params[param_name]])

    # Threshold-based methods: number of bits, threshold and signedness.
    if method in (QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC):
        return {keys.NUM_BITS: act_cfg.activation_n_bits,
                keys.THRESHOLD: _as_array(THRESHOLD),
                keys.SIGNED: act_cfg.activation_quantization_params.get(SIGNED)}

    # Uniform method: number of bits and the explicit min/max range.
    if method in (QuantizationMethod.UNIFORM,):
        return {keys.NUM_BITS: act_cfg.activation_n_bits,
                keys.MIN_RANGE: _as_array(RANGE_MIN),
                keys.MAX_RANGE: _as_array(RANGE_MAX)}

    # LUT power-of-two method: cluster centers on top of threshold and signedness.
    if method in (QuantizationMethod.LUT_POT_QUANTIZER,):
        return {keys.NUM_BITS: act_cfg.activation_n_bits,
                keys.CLUSTER_CENTERS: _as_array(CLUSTER_CENTERS),
                keys.THRESHOLD: _as_array(THRESHOLD),
                keys.SIGNED: act_cfg.activation_quantization_params.get(SIGNED)}
        # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config

    Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
37
89
 
38
90
 
39
- def get_weights_quantizer_for_node(node: BaseNode) -> List[Callable]:
91
+ def get_weights_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuantizer:
40
92
  """
41
93
  Get weights quantizer for a node.
42
94
 
@@ -48,41 +100,20 @@ def get_weights_quantizer_for_node(node: BaseNode) -> List[Callable]:
48
100
 
49
101
  """
50
102
  if node.final_weights_quantization_cfg is None:
51
- Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration')
52
-
103
+ Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration') # pragma:
104
+ # no cover
53
105
  node_w_qc = node.final_weights_quantization_cfg
54
106
  weights_quantization_method = node_w_qc.weights_quantization_method
55
107
 
56
- if weights_quantization_method not in SUPPORTED_WEIGHT_QUANTIZER_TYPES:
57
- Logger.error(f'Fully quantized models are now supported for {SUPPORTED_WEIGHT_QUANTIZER_TYPES} quantization methods, but node has {weights_quantization_method} quantization method')
58
-
59
- # Compute quantizer params based on node's quantization params
60
- if weights_quantization_method in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC]:
61
- weight_thresholds = node_w_qc.weights_quantization_params.get(THRESHOLD)
62
- assert weight_thresholds is not None
63
- if weights_quantization_method == QuantizationMethod.POWER_OF_TWO:
64
- is_threshold_pot = np.all([int(np.log2(x)) == np.log2(x) for x in weight_thresholds.flatten()])
65
- if not is_threshold_pot:
66
- Logger.error(f'Expected threshold to be power of 2 but is {weight_thresholds}')
67
-
68
- min_range = -weight_thresholds
69
- max_range = weight_thresholds - calculate_delta(weight_thresholds,
70
- n_bits=node_w_qc.weights_n_bits,
71
- signed=True)
72
-
73
- else:
74
- Logger.error(f'For now fully quantized models support only {SUPPORTED_WEIGHT_QUANTIZER_TYPES} for weights quantization, but found {weights_quantization_method}')
108
+ quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
109
+ weights_quantization_method,
110
+ BasePyTorchInferableQuantizer)
111
+ kwargs = get_weights_inferable_quantizer_kwargs(node)
75
112
 
76
- return [UniformWeightsQuantizer(num_bits=node_w_qc.weights_n_bits,
77
- max_range=max_range,
78
- min_range=min_range,
79
- quantization_method=node_w_qc.weights_quantization_method,
80
- per_channel=node_w_qc.weights_per_channel_threshold,
81
- output_channels_axis=node_w_qc.weights_channels_axis
82
- )]
113
+ return quantier_for_node(**kwargs)
83
114
 
84
115
 
85
- def get_activations_quantizer_for_node(node: BaseNode) -> List[Callable]:
116
+ def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuantizer:
86
117
  """
87
118
  Get activation quantizer for a node.
88
119
 
@@ -93,43 +124,16 @@ def get_activations_quantizer_for_node(node: BaseNode) -> List[Callable]:
93
124
  Quantizer for the node's activations.
94
125
 
95
126
  """
96
-
97
127
  if node.final_activation_quantization_cfg is None:
98
- Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration')
99
-
128
+ Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration') #
129
+ # pragma: no cover
100
130
  node_act_qc = node.final_activation_quantization_cfg
101
131
  activation_quantization_method = node_act_qc.activation_quantization_method
102
132
 
103
- if activation_quantization_method not in SUPPORTED_ACTIVATION_QUANTIZER_TYPES:
104
- Logger.error(
105
- f'Fully quantized models are now supported for {SUPPORTED_ACTIVATION_QUANTIZER_TYPES} quantization methods, '
106
- f'but node has {activation_quantization_method} quantization method')
107
-
108
- activation_thresholds = node_act_qc.activation_quantization_params.get(THRESHOLD)
109
-
110
- if activation_quantization_method in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC]:
111
- if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
112
- is_threshold_pot = np.all([int(np.log2(x)) == np.log2(x) for x in activation_thresholds.flatten()])
113
- if not is_threshold_pot:
114
- Logger.error(f'Expected threshold to be power of 2 but is {node_act_qc.activation_quantization_params.get(THRESHOLD)}')
133
+ quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
134
+ activation_quantization_method,
135
+ BasePyTorchInferableQuantizer)
136
+ kwargs = get_activation_inferable_quantizer_kwargs(node)
115
137
 
116
- min_range = 0
117
- if node_act_qc.activation_quantization_params.get(SIGNED):
118
- min_range = -activation_thresholds
119
-
120
- max_range = activation_thresholds - calculate_delta(
121
- activation_thresholds,
122
- n_bits=node_act_qc.activation_n_bits,
123
- signed=node_act_qc.activation_quantization_params.get(SIGNED))
124
-
125
- elif activation_quantization_method in [QuantizationMethod.UNIFORM]:
126
- min_range = node_act_qc.activation_quantization_params.get(RANGE_MIN)
127
- max_range = node_act_qc.activation_quantization_params.get(RANGE_MAX)
128
-
129
- else:
130
- Logger.error(f'For now fully quantized models support only {SUPPORTED_ACTIVATION_QUANTIZER_TYPES} for activation quantization, but found {activation_quantization_method}')
138
+ return quantizer_for_node(**kwargs)
131
139
 
132
- return [FakeQuantQuantizer(nbits=node_act_qc.activation_n_bits,
133
- min_range=min_range,
134
- max_range=max_range,
135
- quantization_method=activation_quantization_method)]
@@ -0,0 +1,47 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Tuple, List, Dict
16
+ from model_compression_toolkit.core.common import BaseNode
17
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
18
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
19
+ get_activations_quantizer_for_node, get_weights_quantizer_for_node
20
+
21
+
22
def get_quantization_quantizers(node: BaseNode) -> Tuple[Dict, List]:
    """
    Create the quantizers used to wrap the layer that corresponds to a node.

    Args:
        node: Node to create quantizers for.

    Returns:
        weight_quantizers: A dictionary mapping each kernel attribute name of the
            node's type to the node's weights quantizer. Empty when weights
            quantization is disabled for the node.
        activation_quantizers: A list holding the node's activation quantizer once
            per output. Empty when activation quantization is disabled.

    """
    if node.is_weights_quantization_enabled():
        kernel_attrs = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(node.type)
        # One quantizer object is built and shared by every kernel attribute.
        weight_quantizers = dict.fromkeys(kernel_attrs, get_weights_quantizer_for_node(node))
    else:
        weight_quantizers = {}

    if node.is_activation_quantization_enabled():
        # A list-typed output_shape is read as one entry per output; any other
        # type is treated as a single output.
        if isinstance(node.output_shape, list):
            outputs_count = len(node.output_shape)
        else:
            outputs_count = 1
        shared_quantizer = get_activations_quantizer_for_node(node)
        activation_quantizers = [shared_quantizer for _ in range(outputs_count)]
    else:
        activation_quantizers = []

    return weight_quantizers, activation_quantizers
@@ -0,0 +1,44 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Any
16
+
17
+ from model_compression_toolkit.core.common import Logger
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+
20
if FOUND_TORCH:
    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
        BasePyTorchInferableQuantizer

    def is_pytorch_layer_exportable(layer: Any) -> bool:
        """
        Tell whether a torch Module is a valid exportable module.

        Args:
            layer: PyTorch module to check.

        Returns:
            True when the layer is a PytorchQuantizationWrapper and every one of its
            weights and activation quantizers is a BasePyTorchInferableQuantizer,
            False otherwise.
        """
        # Only quantization wrappers are candidates for export.
        if not isinstance(layer, PytorchQuantizationWrapper):
            return False

        attached_quantizers = [*layer.weights_quantizers.values(), *layer.activation_quantizers]
        return all(isinstance(quantizer, BasePyTorchInferableQuantizer)
                   for quantizer in attached_quantizers)
else:
    def is_pytorch_layer_exportable(*args, **kwargs):  # pragma: no cover
        # Fallback defined when torch was not found: report the missing
        # dependency through the logger.
        Logger.error('Installing torch is mandatory '
                     'when using is_pytorch_layer_exportable. '
                     'Could not find PyTorch package.')
@@ -12,3 +12,9 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
16
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2
17
+ from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization_experimental
18
+ from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
19
+ from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization_experimental
20
+ from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config