mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,269 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List, Union, Any, Dict, Tuple
16
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
17
+ from model_compression_toolkit.core.common.logger import Logger
18
+ from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER, TRAINING
20
+ import inspect
21
+
22
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
23
+ BasePytorchTrainableQuantizer
24
+
25
+ if FOUND_TORCH:
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+
30
+ class PytorchQuantizationWrapper(nn.Module):
31
def __init__(self,
             module: nn.Module,
             weights_quantizers: Dict[str, BaseInferableQuantizer] = None,
             activation_quantizers: List[BaseInferableQuantizer] = None):
    """
    Wrap a PyTorch layer together with the quantizers applied to its weights and outputs.

    Args:
        module: The layer to wrap. An nn.Module instance is registered as a sub-module;
            any other object (a functional layer) is stored as a plain attribute.
        weights_quantizers: A dictionary mapping a weight's name to its quantizer.
            An empty dictionary is used when None is passed.
        activation_quantizers: A list of activation quantizers, one for each layer output.
            An empty list is used when None is passed.
    """
    super().__init__()

    if not isinstance(module, nn.Module):
        # Functional layers
        setattr(self, LAYER, module)
    else:
        self.add_module(LAYER, module)

    self.weights_quantizers = dict() if weights_quantizers is None else weights_quantizers
    self.activation_quantizers = list() if activation_quantizers is None else activation_quantizers

    # Build the weights variables in training mode, then the activation variables.
    self._set_weights_vars(True)
    self._set_activation_vars()
54
+
55
def add_weights_quantizer(self, param_name: str, quantizer: BaseInferableQuantizer):
    """
    Attach a weights quantizer to an already constructed wrapper.

    An existing entry for the same parameter name is replaced.

    Args:
        param_name: The name of the parameter to quantize.
        quantizer: The quantizer to use for that parameter.

    Returns: None

    """
    # NOTE(review): only the name->quantizer mapping is updated here; _weights_vars is
    # not rebuilt by this method.
    self.weights_quantizers[param_name] = quantizer
67
+
68
@property
def is_activation_quantization(self) -> bool:
    """
    Tell whether the wrapper holds at least one activation quantizer.

    Returns: True if one or more activation quantizers exist, False otherwise.

    """
    quantizers_count = self.num_activation_quantizers
    return quantizers_count > 0
76
+
77
@property
def is_weights_quantization(self) -> bool:
    """
    Tell whether the wrapper holds at least one weights quantizer.

    Returns: True if one or more weights quantizers exist, False otherwise.

    """
    quantizers_count = self.num_weights_quantizers
    return quantizers_count > 0
86
+
87
@property
def num_weights_quantizers(self) -> int:
    """
    Returns: the number of entries in the weights quantizers dictionary.
    """
    registered = self.weights_quantizers
    return len(registered)
93
+
94
@property
def num_activation_quantizers(self) -> int:
    """
    Returns: the number of entries in the activation quantizers list.
    """
    registered = self.activation_quantizers
    return len(registered)
100
+
101
def convert_to_inferable_quantizers(self):
    """
    Replace the wrapper's trainable quantizers with their inferable counterparts.

    Activation quantizers are converted first and the activation variables are rebuilt;
    weights quantizers follow and the weights variables are rebuilt in non-training mode.
    A quantizer that is not a BasePytorchTrainableQuantizer is reported through
    Logger.error and is not added to the converted collection.

    """
    # Activation quantizers
    if self.is_activation_quantization:
        converted_activations = []
        for act_quantizer in self.activation_quantizers:
            if not isinstance(act_quantizer, BasePytorchTrainableQuantizer):  # pragma: no cover
                Logger.error('Can only convert trainable quantizers based on BasePytorchTrainableQuantizer')
                continue
            converted_activations.append(act_quantizer.convert2inferable())
        self.activation_quantizers = converted_activations
        self._set_activation_vars()

    # Weight quantizers
    if self.is_weights_quantization:
        converted_weights = {}
        for weight_name, weight_quantizer in self.weights_quantizers.items():
            if not isinstance(weight_quantizer, BasePytorchTrainableQuantizer):  # pragma: no cover
                Logger.error('Can only convert trainable quantizers based on BasePytorchTrainableQuantizer')
                continue
            converted_weights[weight_name] = weight_quantizer.convert2inferable()
        self.weights_quantizers = converted_weights
        self._set_weights_vars(False)
127
+
128
+ def _set_weights_vars(self, is_training: bool = True):
129
+ """
130
+ Initialize learnable weights as parameters in the wrapper, and their quantizers
131
+
132
+ Args:
133
+ is_training: Whether working with InferableQuantizers or not. If so, do not register weight as parameter.
134
+
135
+ """
136
+ self._weights_vars = []
137
+
138
+ # Init weights quantizers
139
+ for name, quantizer in self.weights_quantizers.items():
140
+ if is_training:
141
+ weight = getattr(self.layer, name).detach()
142
+ delattr(self.layer, name)
143
+ setattr(self.layer, name, weight)
144
+ self.register_parameter(name, torch.nn.Parameter(weight, requires_grad=True))
145
+ else:
146
+ weight = getattr(self, name).detach()
147
+ delattr(self.layer, name)
148
+ setattr(self.layer, name, weight)
149
+ quantizer.initialize_quantization(weight.shape, name, self)
150
+ self._weights_vars.append((name, getattr(self, name), quantizer))
151
+
152
+ def _set_activation_vars(self):
153
+ """
154
+ Initialize layer outputs and their quantizers in the wrapper
155
+ """
156
+ self._activation_vars = []
157
+ for i, quantizer in enumerate(self.activation_quantizers):
158
+ quantizer.initialize_quantization(None, f"tensor{i}", self)
159
+ self._activation_vars.append(quantizer)
160
+
161
def set_quantize_weights(self, quantized_weights: dict):
    """
    Write quantized weights into the wrapped layer.

    Only names that have a weights quantizer are written. A name missing from
    quantized_weights results in None being set on the layer for that attribute.

    Args:
        quantized_weights: a dict mapping weight names to the values to set.

    Returns: None

    """
    for attr_name in self.weights_quantizers:
        setattr(self.layer, attr_name, quantized_weights.get(attr_name))
174
+
175
def get_weights_vars(self) -> List[Tuple[str, Any, BaseInferableQuantizer]]:
    """
    Expose the weight variables tracked by the wrapper.

    Returns:
        The list stored in self._weights_vars: one (weight name, weight value, quantizer) tuple per
        quantized weight of the wrapped layer.

    """
    weights_vars = self._weights_vars
    return weights_vars
185
+
186
def forward(self,
            *args: List[Any],
            **kwargs: Dict[str, Any]) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Run the wrapped layer with quantized weights and quantize its outputs.

    Args:
        args: positional arguments forwarded to the wrapped layer.
        kwargs: keyword arguments forwarded to the wrapped layer.

    Returns: the layer output(s) after activation quantization. A single output is returned as-is,
    several outputs are returned as a list. When activation quantization is disabled the raw layer
    output is returned.

    """
    # Step 1: quantize the weights and install them on the wrapped layer.
    if self.is_weights_quantization:
        quantized_weights = {}
        for weight_name, float_weight, weight_quantizer in self._weights_vars:
            call_params = inspect.signature(weight_quantizer.__call__).parameters
            # Only quantizers whose __call__ declares the TRAINING argument receive the training flag.
            if TRAINING in call_params:
                quantized_weights[weight_name] = weight_quantizer(float_weight, self.training)
            else:
                quantized_weights[weight_name] = weight_quantizer(float_weight)

        self.set_quantize_weights(quantized_weights)

    # Step 2: run the wrapped layer.
    outputs = self.layer(*args, **kwargs)

    # Step 3: quantize the activations (skipped entirely when disabled).
    if not self.is_activation_quantization:
        return outputs

    # Anything that is not a list is treated as one single output.
    if not isinstance(outputs, list):
        outputs = [outputs]

    if len(outputs) != self.num_activation_quantizers:
        Logger.error(f"Number of outputs {len(outputs)} is incompatible number of activation quantizers {self.num_activation_quantizers}")  # pragma: no cover

    outputs_quantized = [act_quantizer(output)
                         for act_quantizer, output in zip(self._activation_vars, outputs)]

    if len(outputs_quantized) == 1:
        return outputs_quantized[0]
    return outputs_quantized
240
+
241
def get_quantized_weights(self) -> Dict[str, torch.Tensor]:
    """
    Quantize every tracked weight with its assigned quantizer.

    Returns: A dictionary mapping each weight attribute name to the result of calling its quantizer
    on the weight value reported by get_weights_vars().

    """
    return {weight_name: weight_quantizer(weight)
            for weight_name, weight, weight_quantizer in self.get_weights_vars()}
252
+
253
+ else:
254
class PytorchQuantizationWrapper:
    def __init__(self,
                 layer,
                 weight_quantizers: Dict[str, BaseInferableQuantizer] = None,
                 activation_quantizers: List[BaseInferableQuantizer] = None):
        """
        Fallback defined when the torch package could not be found: constructing it only reports
        a critical error through the Logger.

        Args:
            layer: A pytorch module.
            weight_quantizers: A dictionary between a weight's name to its quantizer.
            activation_quantizers: A list of activations quantization, one for each layer output.
        """
        missing_torch_message = ('Installing Pytorch is mandatory '
                                 'when using PytorchQuantizationWrapper. '
                                 'Could not find torch package.')
        Logger.critical(missing_torch_message)  # pragma: no cover
@@ -0,0 +1,152 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Tuple
16
+
17
+ import torch
18
+ import numpy as np
19
+
20
+
21
def get_working_device():
    """
    Get the working device of the environment

    Returns:
        Device "cuda" if GPU is available, else "cpu"

    """
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def to_torch_tensor(tensor):
    """
    Convert the input to Torch tensor(s) placed on the working device.

    Args:
        tensor: A Torch tensor, a Numpy array, a Python float or int, or a (possibly nested)
            list/tuple of those.

    Returns:
        - Torch tensor input: the tensor moved to the working device.
        - Numpy array input: a float32 Torch tensor.
        - float input: a one-element float Torch tensor.
        - int input: a one-element int32 Torch tensor.
        - list / tuple input: a list / tuple whose elements were converted recursively.

    Raises:
        Exception: if the input type is none of the above.
    """
    working_device = get_working_device()
    if isinstance(tensor, torch.Tensor):
        return tensor.to(working_device)
    if isinstance(tensor, list):
        return [to_torch_tensor(t) for t in tensor]
    if isinstance(tensor, tuple):
        # Build a real tuple: a bare generator expression would hand the caller a one-shot
        # iterator that supports neither len() nor indexing.
        return tuple(to_torch_tensor(t) for t in tensor)
    if isinstance(tensor, np.ndarray):
        return torch.from_numpy(tensor.astype(np.float32)).to(working_device)
    if isinstance(tensor, float):
        return torch.Tensor([tensor]).to(working_device)
    if isinstance(tensor, int):
        return torch.Tensor([tensor]).int().to(working_device)
    # Report the target type itself (torch.Tensor), not its metaclass.
    raise Exception(f'Conversion of type {type(tensor)} to {torch.Tensor} is not supported')
56
+
57
+
58
def fix_range_to_include_zero(range_min: torch.Tensor,
                              range_max: torch.Tensor,
                              n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Adjust a quantization range so that 0.0 is representable on the quantization grid.

    Three cases are handled element-wise:
      - range_min > 0: the lower bound becomes 0 and range_max is kept.
      - range_max < 0: the upper bound becomes 0 and range_min is kept.
      - otherwise: range_min is snapped to a multiple of the quantization step and range_max is
        shifted by the same amount, so the width of the range is unchanged.

    For per-channel quantization, range_min and range_max should be tensors in the specific shape
    that allows quantization along the channel_axis.

    Args:
        range_min: min bound of the quantization range (before adjustment).
        range_max: max bound of the quantization range (before adjustment).
        n_bits: Number of bits to quantize the tensor.
    Returns: adjusted quantization range
    """
    # Float masks selecting which of the three cases applies to each element.
    strictly_positive = range_min > 0
    strictly_negative = range_max < 0
    straddles_zero = torch.logical_not(torch.logical_or(strictly_positive, strictly_negative)).float()
    strictly_positive = strictly_positive.float()
    strictly_negative = strictly_negative.float()

    # Snap the lower bound onto the grid and keep the range width.
    span = range_max - range_min
    step = span / (2 ** n_bits - 1)
    snapped_min = step * torch.round(range_min / step)
    snapped_max = span + snapped_min

    min_range_adj = snapped_min * straddles_zero + strictly_negative * range_min
    max_range_adj = snapped_max * straddles_zero + strictly_positive * range_max
    return min_range_adj, max_range_adj
85
+
86
+
87
def lut_quantizer(tensor_data: torch.Tensor,
                  cluster_centers: torch.Tensor,
                  signed: bool,
                  threshold: torch.Tensor,
                  multiplier_n_bits: int,
                  eps: float) -> torch.Tensor:
    """
    Quantize a tensor non-uniformly by mapping each value to its nearest pre-defined cluster center.

    Steps:
    1. Scale tensor_data by the threshold into the multiplier_n_bits quantization range.
    2. Pick, for every value, the cluster center with the smallest absolute distance.
    3. Scale the chosen centers back: divide by 2 ** (multiplier_n_bits - int(signed)) and multiply
       by the threshold.

    Args:
        tensor_data: Input activation tensor.
        cluster_centers: The cluster centers to assign the tensor values.
        signed: Whether the quantization is signed or not.
        threshold: Threshold for quantization.
        multiplier_n_bits: Number of bits that determines the quantization range
        eps: Small value for numerical stability in division.

    Returns: Quantized tensor.
    """
    # A trailing axis of size 1 lets every value be compared against all cluster centers at once.
    scaled = int_quantization_with_threshold(tensor_data, n_bits=multiplier_n_bits, signed=signed,
                                             threshold=threshold, eps=eps).unsqueeze(-1)

    broadcast_shape = [1] * (len(scaled.shape) - 1) + [-1]
    expanded_cluster_centers = cluster_centers.reshape(broadcast_shape)
    nearest_idx = torch.argmin(torch.abs(scaled - expanded_cluster_centers), dim=-1)
    centers = cluster_centers.flatten()[nearest_idx]

    range_scale = 2 ** (multiplier_n_bits - int(signed))
    return (centers / range_scale) * threshold
122
+
123
+
124
def int_quantization_with_threshold(data: torch.Tensor,
                                    n_bits: int,
                                    signed: bool,
                                    threshold: torch.Tensor,
                                    eps: float) -> torch.Tensor:
    """
    Scale data by the threshold into the integer quantization range and clip it to that range.

    The data is divided by (threshold + eps), multiplied by 2 ** (n_bits - int(signed)) and clipped
    to [-2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1] when signed, or [0, 2 ** n_bits - 1] otherwise.
    Values are clipped only; no rounding is applied here.

    Args:
        data: Tensor data.
        n_bits: Number of bits that determines the quantization range.
        signed: Whether the quantization is signed or not.
        threshold: Threshold for quantization.
        eps: Small value for numerical stability in division.

    Returns:
        The scaled and clipped tensor.

    """
    if signed:
        clip_min, clip_max = -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1
    else:
        clip_min, clip_max = 0, 2 ** n_bits - 1

    scaled = (data / (threshold + eps)) * (2 ** (n_bits - int(signed)))
    return torch.clip(scaled, min=clip_min, max=clip_max)
@@ -0,0 +1,35 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.activation_inferable_quantizers.activation_pot_inferable_quantizer \
17
+ import ActivationPOTInferableQuantizer
18
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.activation_inferable_quantizers.activation_symmetric_inferable_quantizer \
19
+ import ActivationSymmetricInferableQuantizer
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.activation_inferable_quantizers.activation_uniform_inferable_quantizer \
21
+ import ActivationUniformInferableQuantizer
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.activation_inferable_quantizers.activation_lut_pot_inferable_quantizer \
23
+ import ActivationLutPOTInferableQuantizer
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_pytorch_inferable_quantizer \
25
+ import BasePyTorchInferableQuantizer
26
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_pot_inferable_quantizer \
27
+ import WeightsPOTInferableQuantizer
28
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_symmetric_inferable_quantizer \
29
+ import WeightsSymmetricInferableQuantizer
30
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_uniform_inferable_quantizer \
31
+ import WeightsUniformInferableQuantizer
32
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_lut_symmetric_inferable_quantizer \
33
+ import WeightsLUTSymmetricInferableQuantizer
34
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_lut_pot_inferable_quantizer \
35
+ import WeightsLUTPOTInferableQuantizer
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+
17
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
20
+ import mark_quantizer, QuantizationTarget
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
22
+ import MULTIPLIER_N_BITS, EPS
23
+
24
+ if FOUND_TORCH:
25
+ import torch
26
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizer_utils \
27
+ import to_torch_tensor, get_working_device, lut_quantizer
28
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers \
29
+ .base_lut_symmetric_inferable_quantizer import BaseLUTSymmetricInferableQuantizer
30
+
31
+
32
@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                quantization_method=[QuantizationMethod.LUT_POT_QUANTIZER],
                quantizer_type=None)
class ActivationLutPOTInferableQuantizer(BaseLUTSymmetricInferableQuantizer):
    """
    Class for quantizing activations using a lut power-of-two quantizer
    """

    def __init__(self,
                 num_bits: int,
                 cluster_centers: np.ndarray,
                 threshold: np.ndarray,
                 signed: bool,
                 multiplier_n_bits: int = MULTIPLIER_N_BITS,
                 eps: float = EPS):
        """
        Initialize the quantizer with the specified parameters.

        Args:
            num_bits: number of bits to use for quantization
            cluster_centers: the cluster centers to assign the activations
            threshold: threshold for quantizing activations (must hold a single power-of-two value)
            signed: whether to use signed quantization or not
            multiplier_n_bits: Number of bits that determines the quantization range
            eps: Small value for numerical stability in division
        """

        super(ActivationLutPOTInferableQuantizer, self).__init__(
            num_bits=num_bits,
            cluster_centers=cluster_centers,
            threshold=threshold,
            signed=signed,
            multiplier_n_bits=multiplier_n_bits,
            eps=eps)

        # A value is a power of two exactly when its base-2 logarithm is an integer.
        log_threshold = np.log2(threshold.flatten())
        is_threshold_pot = np.all(np.round(log_threshold) == log_threshold)
        assert is_threshold_pot, f'Expected threshold to be power of 2 but is {threshold}'

        # Activation supports only per-tensor quantization
        assert len(
            self.threshold) == 1, f'For activation, quantization per channel is not supported and threshold ' \
                                  f'should be of length 1 but is {len(threshold)}'
        self.threshold = self.threshold[0]

        self.cluster_centers = to_torch_tensor(self.cluster_centers).to(get_working_device())

    def __call__(self, inputs: torch.Tensor):
        """
        Quantize the given inputs using the quantizer parameters.

        Args:
            inputs: input tensor to quantize

        Returns:
            quantized tensor.
        """
        # Work on a detached view rather than assigning inputs.requires_grad = False: that flag
        # can only be changed on leaf tensors (activations computed from trainable parameters are
        # not leaves, so the assignment raises), and assigning it would also modify the caller's
        # tensor.
        detached_inputs = inputs.detach()
        return lut_quantizer(detached_inputs, cluster_centers=self.cluster_centers, signed=self.signed,
                             threshold=self.threshold, multiplier_n_bits=self.multiplier_n_bits, eps=self.eps)
91
+
92
+ else:
93
class ActivationLutPOTInferableQuantizer:  # pragma: no cover
    """
    Placeholder used when the torch package could not be found; it cannot be instantiated.
    """

    def __init__(self, *args, **kwargs):
        missing_torch_message = ('Installing torch is mandatory '
                                 'when using ActivationLutPOTInferableQuantizer. '
                                 'Could not find torch package.')
        raise Exception(missing_torch_message)
@@ -0,0 +1,62 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
21
+ QuantizationTarget
22
+
23
+ if FOUND_TORCH:
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.activation_inferable_quantizers.activation_symmetric_inferable_quantizer import \
25
+ ActivationSymmetricInferableQuantizer
26
+
27
+
28
@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                quantization_method=[QuantizationMethod.POWER_OF_TWO],
                quantizer_type=None)
class ActivationPOTInferableQuantizer(ActivationSymmetricInferableQuantizer):
    """
    Activation quantizer whose threshold must be a power of two.
    """

    def __init__(self,
                 num_bits: int,
                 threshold: np.ndarray,
                 signed: bool):
        """
        Build the quantizer and verify that every threshold value is a power of two.

        Args:
            num_bits: number of bits to use for quantization
            threshold: threshold for quantizing activations
            signed: whether to use signed quantization or not
        """
        # The symmetric activation quantizer does the actual set-up.
        super(ActivationPOTInferableQuantizer, self).__init__(num_bits=num_bits,
                                                              signed=signed,
                                                              threshold=threshold)

        # A value is a power of two exactly when its base-2 logarithm is an integer.
        log_threshold = np.log2(threshold.flatten())
        is_threshold_pot = np.all(np.round(log_threshold) == log_threshold)
        assert is_threshold_pot, f'Expected threshold to be power of 2 but is {threshold}'
55
+
56
+
57
+ else:
58
class ActivationPOTInferableQuantizer:  # pragma: no cover
    """
    Placeholder used when the torch package could not be found; it cannot be instantiated.
    """

    def __init__(self, *args, **kwargs):
        missing_torch_message = ('Installing torch is mandatory '
                                 'when using ActivationPOTInferableQuantizer. '
                                 'Could not find torch package.')
        raise Exception(missing_torch_message)