mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,83 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+
17
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
20
+ QuantizationTarget
21
+
22
if FOUND_TORCH:
    import torch
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_symmetric_inferable_quantizer import \
        BaseSymmetricInferableQuantizer


    @mark_quantizer(quantization_target=QuantizationTarget.Activation,
                    quantization_method=[QuantizationMethod.SYMMETRIC],
                    quantizer_type=None)
    class ActivationSymmetricInferableQuantizer(BaseSymmetricInferableQuantizer):
        """
        Inferable quantizer that applies symmetric, per-tensor fake-quantization to activations.
        """

        def __init__(self, num_bits: int, threshold: np.ndarray, signed: bool):
            """
            Build the quantizer and reduce its parameters to a single (per-tensor) scale.

            Args:
                num_bits: number of bits to use for quantization
                threshold: threshold for quantizing activations (must describe a single tensor-wide value)
                signed: whether to use signed quantization or not
            """
            # NOTE(review): self.scales, self.min_quantized_domain and self.max_quantized_domain are
            # presumably populated by the base initializer; it is not visible from this file.
            super(ActivationSymmetricInferableQuantizer, self).__init__(num_bits=num_bits,
                                                                        threshold=threshold,
                                                                        signed=signed)

            # Activations are quantized per-tensor only, so exactly one scale is allowed.
            is_per_tensor = len(self.scales) == 1
            assert is_per_tensor, f'For activation, quantization per channel is not supported and threshold should ' \
                                  f'be of length 1 but is {len(threshold)}'

            # Keep the single scale as a scalar, as required by the per-tensor fake-quantize call.
            self.scales = self.scales[0]

            # Symmetric quantization: the zero point is fixed at 0.
            self.zero_points = 0

        def __call__(self, inputs: torch.Tensor):
            """
            Fake-quantize the given tensor with the stored scale, zero point and quantized domain.

            Args:
                inputs: input tensor to quantize

            Returns:
                quantized tensor.
            """
            return torch.fake_quantize_per_tensor_affine(
                inputs,
                scale=self.scales,
                zero_point=self.zero_points,
                quant_min=self.min_quantized_domain,
                quant_max=self.max_quantized_domain)

else:
    class ActivationSymmetricInferableQuantizer:  # pragma: no cover
        """
        Placeholder used when torch is not installed; instantiating it always raises.
        """

        def __init__(self, *args, **kwargs):
            raise Exception('Installing torch is mandatory '
                            'when using ActivationSymmetricInferableQuantizer. '
                            'Could not find torch package.')
@@ -0,0 +1,100 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+
17
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
20
+ QuantizationTarget
21
+
22
if FOUND_TORCH:
    import torch
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_uniform_inferable_quantizer import \
        BaseUniformInferableQuantizer


    @mark_quantizer(quantization_target=QuantizationTarget.Activation,
                    quantization_method=[QuantizationMethod.UNIFORM],
                    quantizer_type=None)
    class ActivationUniformInferableQuantizer(BaseUniformInferableQuantizer):
        """
        Class for quantizing activations using a uniform (min/max range), per-tensor quantizer
        """

        def __init__(self,
                     num_bits: int,
                     min_range: np.ndarray,
                     max_range: np.ndarray,
                     ):
            """
            Initialize the quantizer with the specified parameters.

            Args:
                num_bits: number of bits to use for quantization
                min_range: min range for quantizing activations (flat numpy array of length 1)
                max_range: max range for quantizing activations (flat numpy array of length 1)
            """
            # NOTE(review): self.min_quantized_domain / self.max_quantized_domain used in __call__ are
            # presumably set by the base initializer; it is not visible from this file.
            super(ActivationUniformInferableQuantizer, self).__init__(num_bits=num_bits,
                                                                      min_range=min_range,
                                                                      max_range=max_range)

            assert isinstance(min_range,
                              np.ndarray), f'min_range is expected to be numpy array, but is of type {type(min_range)}'
            assert isinstance(max_range,
                              np.ndarray), f'max_range is expected to be numpy array, but is of type {type(max_range)}'
            assert min_range.ndim == 1, f'min_range is expected to be flatten, but of shape {min_range.shape}'
            # Report the shape of max_range itself (the message previously printed min_range's shape).
            assert max_range.ndim == 1, f'max_range is expected to be flatten, but of shape {max_range.shape}'

            assert len(
                min_range) == 1, f'For activation, quantization per channel is not supported and min_range should be ' \
                                 f'of length 1 but is {len(min_range)}'
            assert len(
                max_range) == 1, f'For activation, quantization per channel is not supported and max_range should be ' \
                                 f'of length 1 but is {len(max_range)}'

            # Activation is per-tensor thus we expect only a single min/max values
            min_range = min_range[0]
            max_range = max_range[0]

            # fixing quantization range to include 0
            a = 0 if min_range > 0 else min_range
            b = 0 if max_range < 0 else max_range

            # NOTE(review): if both clipped bounds are 0 the scale is 0 and the zero-point division
            # below fails - confirm callers never pass such a degenerate range.
            self.scale = float((b - a) / ((2 ** num_bits) - 1))
            self.zero_point = int(-np.round(a / self.scale))  # zp has to be positive, and a <=0, so we multiply by -1

        def __call__(self, inputs: torch.Tensor):
            """
            Quantize the given inputs using the quantizer parameters.

            Args:
                inputs: input tensor to quantize

            Returns:
                quantized tensor.
            """
            return torch.fake_quantize_per_tensor_affine(inputs,
                                                         scale=self.scale,
                                                         zero_point=self.zero_point,
                                                         quant_min=self.min_quantized_domain,
                                                         quant_max=self.max_quantized_domain)


else:
    class ActivationUniformInferableQuantizer:  # pragma: no cover
        """
        Placeholder used when torch is not installed; instantiating it always raises.
        """

        def __init__(self, *args, **kwargs):
            raise Exception('Installing torch is mandatory '
                            'when using ActivationUniformInferableQuantizer. '
                            'Could not find torch package.')
@@ -0,0 +1,95 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+ import warnings
17
+
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
21
+ import mark_quantizer
22
+
23
+ if FOUND_TORCH:
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers \
25
+ .base_pytorch_inferable_quantizer import BasePyTorchInferableQuantizer
26
+
27
+
28
@mark_quantizer(quantization_target=None,
                quantization_method=[QuantizationMethod.LUT_SYM_QUANTIZER],
                quantizer_type=None)
class BaseLUTSymmetricInferableQuantizer(BasePyTorchInferableQuantizer):
    """
    Base class for look-up-table (LUT) symmetric inferable quantizers in PyTorch.

    The constructor validates the LUT configuration and stores it on the instance.
    """

    def __init__(self,
                 num_bits: int,
                 cluster_centers: np.ndarray,
                 threshold: np.ndarray,
                 signed: bool,
                 multiplier_n_bits: int,
                 eps: float):
        """
        Validate the quantizer parameters and keep them as attributes.

        Args:
            num_bits: number of bits to use for quantization
            cluster_centers: the cluster centers to assign the values
            threshold: threshold for quantizing values
            signed: whether or not to use signed quantization
            multiplier_n_bits: Number of bits that determines the quantization range
            eps: Small value for numerical stability in division

        Raises:
            AssertionError: if any of the parameter checks below fails.
        """

        super(BaseLUTSymmetricInferableQuantizer, self).__init__()

        # The threshold has to be a flat (1-D) numpy array.
        threshold_is_array = isinstance(threshold, np.ndarray)
        assert threshold_is_array, f'Threshold is expected to be numpy array, but is of type {type(threshold)}'
        assert threshold.ndim == 1, f'Threshold is expected to be flatten, but of shape {threshold.shape}'

        # At most 2 ** num_bits distinct cluster centers are allowed.
        max_num_centers = 2 ** num_bits
        num_distinct_centers = len(np.unique(cluster_centers))
        assert num_distinct_centers <= max_num_centers, \
            f'Expected num of cluster centers to be less or equal than {2 ** num_bits} ' \
            f'but got {len(cluster_centers)}'

        # Cluster centers must hold integer values: casting them to int must not change them.
        fractional_part = cluster_centers - cluster_centers.astype(int)
        assert not np.any(fractional_part), 'Expected cluster centers to be integers'

        if signed:
            # One bit is reserved for the sign, the rest determine the magnitude range.
            magnitude_bits = multiplier_n_bits - int(signed)
            lower_bound = -1 * (2 ** magnitude_bits)
            upper_bound = 2 ** magnitude_bits - 1
            centers_in_range = (lower_bound <= cluster_centers) & (cluster_centers <= upper_bound)
            assert np.all(centers_in_range), 'Expected cluster centers in the quantization range'
        else:
            # NOTE(review): the upper bound is inclusive of 2 ** multiplier_n_bits - confirm this is intended.
            assert np.all(cluster_centers <= (2 ** multiplier_n_bits)), \
                'Expected cluster centers in the quantization range'
            # In the unsigned case no cluster center may be negative.
            assert np.all(cluster_centers >= 0), \
                'Expected unsigned cluster centers in unsigned activation quantization'

        # num_bits must not exceed multiplier_n_bits; equality is accepted but triggers a warning.
        assert num_bits <= multiplier_n_bits, \
            f'Look-Up-Table bit configuration has {num_bits} bits. It must be ' \
            f'less then {multiplier_n_bits}'
        if num_bits == multiplier_n_bits:
            warnings.warn("Num of bits equal to multiplier n bits, Please be aware LUT quantizier may be "
                          "inefficient in that case, consider using SymmetricInferableQuantizer instead")

        # Store the validated configuration.
        self.signed = signed
        self.threshold = threshold
        self.cluster_centers = cluster_centers
        self.num_bits = num_bits
        self.multiplier_n_bits = multiplier_n_bits
        self.eps = eps
90
+
91
+ else:
92
class BaseLUTSymmetricInferableQuantizer:  # pragma: no cover
    """
    Stand-in class that cannot be instantiated: constructing it reports that torch is missing.
    """

    def __init__(self, *args, **kwargs):
        """
        Reject any instantiation attempt.

        Args:
            *args: ignored.
            **kwargs: ignored.

        Raises:
            Exception: always, stating that torch must be installed.
        """
        missing_torch_message = ('Installing torch is mandatory when using BaseLUTSymmetricInferableQuantizer. '
                                 'Could not find torch package.')
        raise Exception(missing_torch_message)
@@ -0,0 +1,48 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from abc import abstractmethod
16
+
17
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
19
+
20
+ if FOUND_TORCH:
21
+ import torch
22
+
23
+
24
class BasePyTorchInferableQuantizer(BaseInferableQuantizer):
    """
    Base class for PyTorch quantizers that are used for inference only.

    Concrete quantizers are expected to override __call__.
    """

    def __init__(self):
        """
        This class is a base quantizer for PyTorch quantizers for inference only.
        """
        super(BasePyTorchInferableQuantizer, self).__init__()

    @abstractmethod
    def __call__(self, inputs: torch.Tensor):
        """
        Quantize the given inputs using the quantizer parameters.

        Args:
            inputs: input tensor to quantize

        Returns:
            quantized tensor.

        Raises:
            NotImplementedError: when this base implementation is invoked directly.
        """
        # NotImplementedError is the exception type; the NotImplemented constant is not callable,
        # so raising NotImplemented(...) would fail with an unrelated TypeError.
        raise NotImplementedError(f'{self.__class__.__name__} did not implement __call__')  # pragma: no cover
43
+ else:
44
class BasePyTorchInferableQuantizer:  # pragma: no cover
    """
    Stand-in class that cannot be instantiated: constructing it reports that torch is missing.
    """

    def __init__(self, *args, **kwargs):
        """
        Reject any instantiation attempt.

        Args:
            *args: ignored.
            **kwargs: ignored.

        Raises:
            Exception: always, stating that torch must be installed.
        """
        missing_torch_message = ('Installing torch is mandatory '
                                 'when using BasePyTorchInferableQuantizer. '
                                 'Could not find torch package.')
        raise Exception(missing_torch_message)
@@ -0,0 +1,70 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+
17
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
20
+
21
+ if FOUND_TORCH:
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_pytorch_inferable_quantizer import \
23
+ BasePyTorchInferableQuantizer
24
+
25
+
26
@mark_quantizer(quantization_target=None,
                quantization_method=[QuantizationMethod.SYMMETRIC],
                quantizer_type=None)
class BaseSymmetricInferableQuantizer(BasePyTorchInferableQuantizer):
    """
    Base class for symmetric inferable quantizers in PyTorch.

    The constructor validates the threshold and derives the integer quantization domain and the scales.
    """

    def __init__(self,
                 num_bits: int,
                 threshold: np.ndarray,
                 signed: bool):
        """
        Validate the threshold and compute the quantization domain and scales.

        Args:
            num_bits: number of bits to use for quantization
            threshold: threshold for quantizing weights
            signed: whether or not to use signed quantization

        Raises:
            AssertionError: if threshold is not a 1-D numpy array.
        """

        super(BaseSymmetricInferableQuantizer, self).__init__()

        threshold_is_array = isinstance(threshold, np.ndarray)
        assert threshold_is_array, f'Threshold is expected to be numpy array, but is of type {type(threshold)}'
        assert threshold.ndim == 1, f'Threshold is expected to be flatten, but of shape {threshold.shape}'

        self.signed = signed
        self.threshold = threshold
        self.num_bits = num_bits

        # The divisor of the threshold: 2 ** (num_bits - 1) when signed, 2 ** num_bits otherwise.
        # The quantized domain is [-divisor, divisor - 1] when signed and [0, divisor - 1] otherwise.
        domain_divisor = 2 ** (num_bits - 1) if signed else 2 ** num_bits
        self.min_quantized_domain = -domain_divisor if signed else 0
        self.max_quantized_domain = domain_divisor - 1
        self.scales = threshold / domain_divisor
62
+
63
+
64
+
65
+ else:
66
class BaseSymmetricInferableQuantizer:  # pragma: no cover
    """
    Stand-in class that cannot be instantiated: constructing it reports that torch is missing.
    """

    def __init__(self, *args, **kwargs):
        """
        Reject any instantiation attempt.

        Args:
            *args: ignored.
            **kwargs: ignored.

        Raises:
            Exception: always, stating that torch must be installed.
        """
        missing_torch_message = ('Installing torch is mandatory '
                                 'when using BaseSymmetricInferableQuantizer. '
                                 'Could not find torch package.')
        raise Exception(missing_torch_message)
@@ -0,0 +1,57 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+
17
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
20
+
21
+ if FOUND_TORCH:
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_pytorch_inferable_quantizer import \
23
+ BasePyTorchInferableQuantizer
24
+
25
+
26
@mark_quantizer(quantization_target=None,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=None)
class BaseUniformInferableQuantizer(BasePyTorchInferableQuantizer):
    """
    Base class for uniform inferable quantizers in PyTorch.

    The constructor stores the range parameters and sets the quantized domain to [0, 2 ** num_bits - 1].
    """

    def __init__(self,
                 num_bits: int,
                 min_range: np.ndarray,
                 max_range: np.ndarray):
        """
        Store the quantizer parameters and derive the quantized domain.

        Args:
            num_bits: number of bits to use for quantization
            min_range: min quantization range for quantizing
            max_range: max quantization range for quantizing
        """

        super(BaseUniformInferableQuantizer, self).__init__()

        # Quantization parameters, kept exactly as provided (no validation is done here).
        self.num_bits, self.min_range, self.max_range = num_bits, min_range, max_range

        # Quantized values span the levels 0 .. 2 ** num_bits - 1.
        highest_level = (2 ** num_bits) - 1
        self.min_quantized_domain = 0
        self.max_quantized_domain = highest_level
50
+
51
+
52
+ else:
53
class BaseUniformInferableQuantizer:  # pragma: no cover
    """
    Stand-in class that cannot be instantiated: constructing it reports that torch is missing.
    """

    def __init__(self, *args, **kwargs):
        """
        Reject any instantiation attempt.

        Args:
            *args: ignored.
            **kwargs: ignored.

        Raises:
            Exception: always, stating that torch must be installed.
        """
        missing_torch_message = ('Installing torch is mandatory '
                                 'when using BaseUniformInferableQuantizer. '
                                 'Could not find torch package.')
        raise Exception(missing_torch_message)
@@ -0,0 +1,26 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
# Names of the parameters that appear in the signatures of the inferable PyTorch quantizers.

# Bit-width, signedness and threshold parameters:
NUM_BITS = "num_bits"
SIGNED = "signed"
THRESHOLD = "threshold"

# Per-channel quantization parameters:
PER_CHANNEL = "per_channel"
CHANNEL_AXIS = "channel_axis"

# Range parameters:
MIN_RANGE = "min_range"
MAX_RANGE = "max_range"

# Cluster-centers (look-up-table) parameters:
CLUSTER_CENTERS = "cluster_centers"
MULTIPLIER_N_BITS = "multiplier_n_bits"
EPS = "eps"
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,77 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
21
+ import mark_quantizer, QuantizationTarget
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
23
+ import MULTIPLIER_N_BITS, EPS
24
+
25
+ if FOUND_TORCH:
26
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers \
27
+ .weights_inferable_quantizers.weights_lut_symmetric_inferable_quantizer import \
28
+ WeightsLUTSymmetricInferableQuantizer
29
+
30
+
31
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.LUT_POT_QUANTIZER],
                quantizer_type=None)
class WeightsLUTPOTInferableQuantizer(WeightsLUTSymmetricInferableQuantizer):
    """
    Class for quantizing weights using lut power-of-two quantizer
    """

    def __init__(self,
                 num_bits: int,
                 cluster_centers: np.ndarray,
                 threshold: np.ndarray,
                 per_channel: bool,
                 channel_axis: int = None,
                 multiplier_n_bits: int = MULTIPLIER_N_BITS,
                 eps: float = EPS):
        """
        Initialize the quantizer with the specified parameters.

        Args:
            num_bits: number of bits to use for quantization
            cluster_centers: the cluster centers to assign the weights
            threshold: threshold for quantizing weights
            per_channel: whether to use per-channel quantization
            channel_axis: Axis of input to apply per-channel quantization on
            multiplier_n_bits: Number of bits that determines the quantization range
            eps: Small value for numerical stability in division

        Raises:
            AssertionError: if any threshold value is not a power of two.
        """
        # target of Weights quantization
        super(WeightsLUTPOTInferableQuantizer, self).__init__(num_bits=num_bits,
                                                              threshold=threshold,
                                                              cluster_centers=cluster_centers,
                                                              per_channel=per_channel,
                                                              channel_axis=channel_axis,
                                                              multiplier_n_bits=multiplier_n_bits,
                                                              eps=eps)

        # A threshold is a power of two iff its base-2 exponent is a finite integer.
        # The finiteness requirement matters: log2(0) is -inf and round(-inf) == -inf, so comparing
        # the rounded exponent with the exponent alone would accept a zero (or infinite) threshold.
        # Negative thresholds give nan and are rejected as well. The numpy floating-point warnings
        # are silenced because the assertion below reports the invalid threshold.
        with np.errstate(divide='ignore', invalid='ignore'):
            threshold_exponents = np.log2(threshold.flatten())
        is_threshold_pot = np.all(np.isfinite(threshold_exponents) &
                                  (np.round(threshold_exponents) == threshold_exponents))
        assert is_threshold_pot, f'Expected threshold to be power of 2 but is {threshold}'
70
+
71
+
72
+ else:
73
class WeightsLUTPOTInferableQuantizer:  # pragma: no cover
    """
    Stand-in class that cannot be instantiated: constructing it reports that torch is missing.
    """

    def __init__(self, *args, **kwargs):
        """
        Reject any instantiation attempt.

        Args:
            *args: ignored.
            **kwargs: ignored.

        Raises:
            Exception: always, stating that torch must be installed.
        """
        missing_torch_message = ('Installing torch is mandatory '
                                 'when using WeightsLUTPOTInferableQuantizer. '
                                 'Could not find torch package.')
        raise Exception(missing_torch_message)