mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,49 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+ from typing import Tuple
17
+
18
+
19
def adjust_range_to_include_zero(range_min: np.ndarray,
                                 range_max: np.ndarray,
                                 n_bits: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Adjust a quantization range so that 0.0 falls exactly on the quantization grid.

    For per-channel quantization, range_min/range_max should be tensors shaped so
    that the arithmetic broadcasts along the channel axis.

    Args:
        range_min: Min bound of the quantization range (before adjustment).
        range_max: Max bound of the quantization range (before adjustment).
        n_bits: Number of bits used to quantize the tensor.

    Returns:
        Tuple of (adjusted min, adjusted max) satisfying min <= 0 <= max.
    """
    scale = (range_max - range_min) / (2 ** n_bits - 1)

    # Guard against a degenerate (constant) range: scale == 0 would otherwise
    # trigger a division by zero and propagate NaNs through np.round. The
    # substituted divisor value is irrelevant because min_range_adj is
    # multiplied by the (zero) scale right below.
    safe_scale = np.where(scale == 0, 1.0, scale)

    # Snap the min bound onto the quantization grid, then shift the max bound
    # by the same amount so the range width is preserved.
    min_range_adj = scale * np.round(range_min / safe_scale)
    max_range_adj = range_max - range_min + min_range_adj

    # Classify each (per-channel) range: entirely positive, entirely negative,
    # or straddling zero.
    min_positive = range_min > 0
    max_negative = range_max < 0
    mid_range = np.logical_and(np.logical_not(min_positive), np.logical_not(max_negative))

    # Ranges that do not straddle zero are clamped so zero becomes an endpoint
    # of the range instead of shifting the whole grid.
    min_range_adj = min_range_adj * mid_range + max_negative * range_min
    max_range_adj = max_range_adj * mid_range + min_positive * range_max

    # Make sure min_range_adj <= 0 and max_range_adj >= 0 to absorb small
    # numeric errors from the grid snapping above.
    min_range_adj = np.minimum(min_range_adj, 0)
    max_range_adj = np.maximum(max_range_adj, 0)

    return min_range_adj, max_range_adj
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -14,15 +14,17 @@
14
14
  # ==============================================================================
15
15
  from model_compression_toolkit.core.common import Logger
16
16
  from model_compression_toolkit.core.common.constants import FOUND_TF
17
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses import get_all_subclasses
17
18
 
18
19
  if FOUND_TF:
19
20
  import tensorflow as tf
20
- from model_compression_toolkit import qunatizers_infrastructure as qi
21
- from model_compression_toolkit.qunatizers_infrastructure.keras.base_keras_quantizer import BaseKerasQuantizer
22
-
21
+ from model_compression_toolkit import quantizers_infrastructure as qi
22
+ from model_compression_toolkit.quantizers_infrastructure import BaseKerasTrainableQuantizer
23
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer \
24
+ import \
25
+ BaseKerasInferableQuantizer
23
26
  keras = tf.keras
24
27
 
25
-
26
28
  def keras_load_quantized_model(filepath, custom_objects=None, compile=True, options=None):
27
29
  """
28
30
  This function wraps the keras load model and MCT quantization custom class to it.
@@ -36,9 +38,25 @@ if FOUND_TF:
36
38
  Returns: A keras Model
37
39
 
38
40
  """
39
- qi_custom_objects = {subclass.__name__: subclass for subclass in BaseKerasQuantizer.__subclasses__()}
40
- qi_custom_objects.update({qi.KerasQuantizationWrapper.__name__: qi.KerasQuantizationWrapper,
41
- qi.KerasNodeQuantizationDispatcher.__name__: qi.KerasNodeQuantizationDispatcher})
41
+ qi_inferable_custom_objects = {subclass.__name__: subclass for subclass in
42
+ get_all_subclasses(BaseKerasInferableQuantizer)}
43
+ all_inferable_names = list(qi_inferable_custom_objects.keys())
44
+ if len(set(all_inferable_names)) < len(all_inferable_names):
45
+ Logger.error(f"Found multiple quantizers with the same name that inherit from BaseKerasInferableQuantizer"
46
+ f"while trying to load a model.")
47
+
48
+ qi_trainable_custom_objects = {subclass.__name__: subclass for subclass in
49
+ get_all_subclasses(BaseKerasTrainableQuantizer)}
50
+ all_trainable_names = list(qi_trainable_custom_objects.keys())
51
+ if len(set(all_trainable_names)) < len(all_trainable_names):
52
+ Logger.error(f"Found multiple quantizers with the same name that inherit from BaseKerasTrainableQuantizer"
53
+ f"while trying to load a model.")
54
+
55
+ # Merge dictionaries into one dict
56
+ qi_custom_objects = {**qi_inferable_custom_objects, **qi_trainable_custom_objects}
57
+
58
+ # Add non-quantizers custom objects
59
+ qi_custom_objects.update({qi.KerasQuantizationWrapper.__name__: qi.KerasQuantizationWrapper})
42
60
  if custom_objects is not None:
43
61
  qi_custom_objects.update(custom_objects)
44
62
  return tf.keras.models.load_model(filepath,
@@ -60,4 +78,4 @@ else:
60
78
  """
61
79
  Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
62
80
  'when using keras_load_quantized_model. '
63
- 'Could not find Tensorflow package.')
81
+ 'Could not find Tensorflow package.') # pragma: no cover
@@ -0,0 +1,345 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Dict, List, Any, Tuple
16
+ from model_compression_toolkit import quantizers_infrastructure as qi
17
+ from model_compression_toolkit.core.common.constants import FOUND_TF
18
+ from model_compression_toolkit.core.common.logger import Logger
19
+ from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import WEIGHTS_QUANTIZERS, ACTIVATION_QUANTIZERS, LAYER, STEPS, TRAINING
21
+
22
+ if FOUND_TF:
23
+ import tensorflow as tf
24
+ from tensorflow.python.util import tf_inspect
25
+ from tensorflow_model_optimization.python.core.keras import utils
26
+
27
+ keras = tf.keras
28
+
29
+ def _make_quantizer_fn(quantizer, x, training):
30
+ """Use currying to return True/False specialized fns to the cond."""
31
+
32
+ def quantizer_fn():
33
+ return quantizer(x, training)
34
+
35
+ return quantizer_fn
36
+
37
+
38
+ def _weight_name(name: str) -> str:
39
+ """Extracts the weight name from the full TensorFlow variable name.
40
+
41
+ For example, returns 'kernel' for 'dense_2/kernel:0'.
42
+
43
+ Args:
44
+ name: TensorFlow variable name.
45
+
46
+ Returns:
47
+ Extracted weight name.
48
+ """
49
+ return name.split(':')[0].split('/')[-1]
50
+
51
+
52
class KerasQuantizationWrapper(tf.keras.layers.Wrapper):
    """Wraps a keras layer together with its weights and activation quantizers,
    producing a layer that emulates quantized inference (and supports training
    when the attached quantizers are trainable)."""

    def __init__(self,
                 layer,
                 weights_quantizers: Dict[str, BaseInferableQuantizer] = None,
                 activation_quantizers: List[BaseInferableQuantizer] = None,
                 **kwargs):
        """
        Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.

        Args:
            layer: A keras layer.
            weights_quantizers: A dictionary between a weight's name to its quantizer.
            activation_quantizers: A list of activations quantization, one for each layer output.
        """
        super(KerasQuantizationWrapper, self).__init__(layer, **kwargs)
        # Track the wrapped layer so its variables are saved/restored with the wrapper.
        self._track_trackable(layer, name='layer')
        self.weights_quantizers = weights_quantizers if weights_quantizers is not None else dict()
        self.activation_quantizers = activation_quantizers if activation_quantizers is not None else list()

    def add_weights_quantizer(self, param_name: str, quantizer: BaseInferableQuantizer):
        """
        This function adds a weights quantizer to existing wrapper.

        Args:
            param_name: The name of the parameter to quantize.
            quantizer: A quantizer.

        Returns: None

        """
        self.weights_quantizers.update({param_name: quantizer})

    @property
    def is_activation_quantization(self) -> bool:
        """
        This function checks whether an activation quantizer exists in the wrapper.

        Returns: a boolean if activation quantizer exists

        """
        return self.num_activation_quantizers > 0

    @property
    def is_weights_quantization(self) -> bool:
        """
        This function checks whether a weights quantizer exists in the wrapper.

        Returns: a boolean if weights quantizer exists

        """
        return self.num_weights_quantizers > 0

    @property
    def num_weights_quantizers(self) -> int:
        """
        Returns: number of weights quantizers
        """
        return len(self.weights_quantizers)

    @property
    def num_activation_quantizers(self) -> int:
        """
        Returns: number of activations quantizers
        """
        return len(self.activation_quantizers)

    def get_config(self):
        """
        Returns: Configuration of KerasQuantizationWrapper.

        """
        base_config = super(KerasQuantizationWrapper, self).get_config()
        config = {
            ACTIVATION_QUANTIZERS: [keras.utils.serialize_keras_object(act) for act in self.activation_quantizers],
            WEIGHTS_QUANTIZERS: {k: keras.utils.serialize_keras_object(v) for k, v in self.weights_quantizers.items()}}
        return dict(list(base_config.items()) + list(config.items()))

    def _set_weights_vars(self, is_training: bool = True):
        """
        This function sets weights quantizers vars to the layer.

        Args:
            is_training: Flag to indicate whether training or not.

        Returns: None
        """
        self._weights_vars = []
        for name, quantizer in self.weights_quantizers.items():
            weight = getattr(self.layer, name)
            # When not training, no weight name is passed so the quantizer skips
            # creating trainable variables for it.
            quantizer.initialize_quantization(weight.shape, _weight_name(weight.name) if is_training else None,
                                              self)
            self._weights_vars.append((name, weight, quantizer))
            self._trainable_weights.append(weight)  # Must when inherit from tf.keras.layers.Wrapper in tf2.10 and below

    def _set_activations_vars(self):
        """
        This function sets activations quantizers vars to the layer.

        Returns: None
        """
        self._activation_vars = []
        for i, quantizer in enumerate(self.activation_quantizers):
            quantizer.initialize_quantization(None, self.layer.name + f'/out{i}', self)
            self._activation_vars.append(quantizer)

    @classmethod
    def from_config(cls, config):
        """
        Builds a KerasQuantizationWrapper from its serialized configuration.

        Args:
            config(dict): dictionary of KerasQuantizationWrapper Configuration

        Returns: A KerasQuantizationWrapper

        """
        config = config.copy()
        activation_quantizers = [keras.utils.deserialize_keras_object(act,
                                                                      module_objects=globals(),
                                                                      custom_objects=None) for act in
                                 config.pop(ACTIVATION_QUANTIZERS)]
        weights_quantizers = {k: keras.utils.deserialize_keras_object(v,
                                                                      module_objects=globals(),
                                                                      custom_objects=None) for k, v in
                              config.pop(WEIGHTS_QUANTIZERS).items()}
        layer = tf.keras.layers.deserialize(config.pop(LAYER))
        return cls(layer=layer, weights_quantizers=weights_quantizers, activation_quantizers=activation_quantizers, **config)

    def build(self, input_shape):
        """
        KerasQuantization Wrapper build function.

        Args:
            input_shape: the layer input shape

        Returns: None

        """
        super(KerasQuantizationWrapper, self).build(input_shape)

        # Non-trainable step counter; -1 marks "not yet stepped".
        self.optimizer_step = self.add_weight(
            STEPS,
            initializer=tf.keras.initializers.Constant(-1),
            dtype=tf.dtypes.int32,
            trainable=False)

        self._set_weights_vars()
        self._set_activations_vars()

    def set_quantize_weights(self, quantized_weights: dict):
        """
        This function updates layer weights after quantization.

        Args:
            quantized_weights: a dict of weight to update

        Returns: None

        """
        for weight_attr in self.weights_quantizers.keys():
            weight = quantized_weights.get(weight_attr)
            current_weight = getattr(self.layer, weight_attr)
            if current_weight.shape != weight.shape:
                Logger.error(
                    f"Existing layer weight shape {current_weight.shape} is incompatible with provided weight "
                    f"shape {weight.shape}")  # pragma: no cover

            setattr(self.layer, weight_attr, weight)

    def call(self, inputs, training=None, **kwargs):
        """
        KerasQuantizationWrapper call function.

        Args:
            inputs: Input tensors to specified layer
            training: a boolean stating if layer is in training mode.
            **kwargs:

        Returns: tensors that simulate a quantized layer.

        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        # Quantize all weights, and replace them in the underlying layer.
        quantized_weights = {}
        for name, unquantized_weight, quantizer in self._weights_vars:

            # Trainable quantizers accept a 'training' argument; inferable ones do not.
            weights_quantizer_args_spec = tf_inspect.getfullargspec(quantizer.__call__).args
            if TRAINING in weights_quantizer_args_spec:
                quantized_weight = utils.smart_cond(
                    training,
                    _make_quantizer_fn(quantizer, unquantized_weight, True),
                    _make_quantizer_fn(quantizer, unquantized_weight, False))
                quantized_weights.update({name: quantized_weight})
            else:
                # Keras weights inferable quantizer
                quantized_weight = quantizer(unquantized_weight)
                quantized_weights.update({name: quantized_weight})

        self.set_quantize_weights(quantized_weights)

        # Only forward 'training' if the wrapped layer's call signature accepts it.
        args_spec = tf_inspect.getfullargspec(self.layer.call).args
        if TRAINING in args_spec:
            outputs = self.layer.call(inputs, training=training, **kwargs)
        else:
            outputs = self.layer.call(inputs, **kwargs)

        # Quantize all activations if quantizers exist.
        if self.is_activation_quantization:
            num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1
            if self.num_activation_quantizers != num_outputs:
                # Bug fix: close the parenthesis opened around the mismatch values.
                Logger.error('Quantization wrapper output quantization error: '
                             f'number of outputs and quantizers mismatch ({num_outputs}!='
                             f'{self.num_activation_quantizers})')
            if num_outputs == 1:
                outputs = [outputs]

            _outputs = []
            for _output, act_quant in zip(outputs, self.activation_quantizers):
                activation_quantizer_args_spec = tf_inspect.getfullargspec(act_quant.__call__).args
                if TRAINING in activation_quantizer_args_spec:
                    _outputs.append(utils.smart_cond(
                        training,
                        _make_quantizer_fn(act_quant, _output, True),
                        _make_quantizer_fn(act_quant, _output, False)))
                else:
                    # Keras activation inferable quantizer.
                    _outputs.append(act_quant(_output))
            outputs = _outputs[0] if num_outputs == 1 else _outputs

        return outputs

    def convert_to_inferable_quantizers(self):
        """
        Convert layer's quantizers to inferable.

        Returns:
            None
        """
        # Activations quantizers
        inferable_activation_quantizers = []
        if self.is_activation_quantization:
            for quantizer in self.activation_quantizers:
                if isinstance(quantizer, qi.BaseKerasTrainableQuantizer):
                    inferable_activation_quantizers.append(quantizer.convert2inferable())
            self.activation_quantizers = inferable_activation_quantizers
            self._set_activations_vars()

        # Weight quantizers
        inferable_weight_quantizers = {}
        if self.is_weights_quantization:
            for name, quantizer in self.weights_quantizers.items():
                if isinstance(quantizer, qi.BaseKerasTrainableQuantizer):
                    inferable_weight_quantizers.update({name: quantizer.convert2inferable()})
            self.weights_quantizers = inferable_weight_quantizers
            self._set_weights_vars(False)

    def get_weights_vars(self) -> List[Tuple[str, Any, BaseInferableQuantizer]]:
        """
        A getter of the layer's weights variables.

        Returns:
            List of tuples of the wrapped layer's weights variables with weight name, values and assigned quantizer.

        """

        return self._weights_vars

    def get_quantized_weights(self) -> Dict[str, tf.Tensor]:
        """
        Applies each weight's quantizer to its current value.

        Returns: A dictionary of weights attributes to quantized weights.

        """
        quantized_weights = {}
        weights_var = self.get_weights_vars()
        for name, w, quantizer in weights_var:
            quantized_weights[name] = quantizer(w)
        return quantized_weights
328
+
329
else:
    # Fallback stub used when TensorFlow is not installed: it preserves the
    # public symbol so module imports still succeed, while any attempt to
    # instantiate it aborts with an informative error via Logger.critical.
    class KerasQuantizationWrapper(object):
        def __init__(self,
                     layer,
                     weights_quantizers: Dict[str, BaseInferableQuantizer] = None,
                     activation_quantizers: List[BaseInferableQuantizer] = None):
            """
            Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.

            This stub is only reached when TensorFlow is missing; it never returns
            normally (Logger.critical raises/aborts).

            Args:
                layer: A keras layer.
                weights_quantizers: A dictionary between a weight's name to its quantizer.
                activation_quantizers: A list of activations quantization, one for each layer output.
            """
            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                            'when using KerasQuantizationWrapper. '
                            'Could not find Tensorflow package.')  # pragma: no cover
@@ -0,0 +1,85 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+ import tensorflow as tf
18
+
19
+
20
def lut_quantizer(tensor_data: tf.Tensor,
                  cluster_centers: np.ndarray,
                  signed: bool,
                  threshold: np.ndarray,
                  multiplier_n_bits: int,
                  eps: float) -> tf.Tensor:
    """
    Quantize a tensor using a non-uniform quantization based on the pre-defined clusters.
    1. Scales tensor_data with the threshold into multiplier_n_bits quantization range.
    2. Assigns cluster centers to each value.
    3. Scales back by multiplying the result by threshold and dividing with the quantization range max value.
    The result is the quantized tensor.

    Args:
        tensor_data: Input activation tensor.
        cluster_centers: the cluster centers to assign the tensor values.
        signed: Whether the quantization is signed or not.
        threshold: threshold for quantization.
        multiplier_n_bits: Number of bits that determines the quantization range
        eps: Small value for numerical stability in division.

    Returns: Quantized tensor.
    """

    # Step 1: scale into the integer quantization range.
    scaled = int_quantization_with_threshold(tensor_data, n_bits=multiplier_n_bits, signed=signed,
                                             threshold=threshold, eps=eps)
    # Add a trailing axis so every element can be compared against all centers at once.
    scaled = tf.expand_dims(scaled, -1)

    # Reshape centers to broadcast along the new trailing axis.
    broadcast_shape = [1] * (len(scaled.shape) - 1) + [-1]
    centers_row = cluster_centers.reshape(broadcast_shape)

    # Step 2: nearest-center assignment by absolute distance.
    assignments = tf.argmin(tf.abs(scaled - centers_row), axis=-1)
    assigned_centers = tf.gather(cluster_centers.flatten(), assignments)

    # Step 3: rescale the assigned centers back to the original value range.
    return (assigned_centers / (2 ** (multiplier_n_bits - int(signed)))) * threshold
55
+
56
+
57
def int_quantization_with_threshold(data: tf.Tensor,
                                    n_bits: int,
                                    signed: bool,
                                    threshold: np.ndarray,
                                    eps: float) -> tf.Tensor:
    """
    Divides data by threshold and quantize it to integers in the quantization range (depends on signed value).

    Args:
        data: tensor data.
        n_bits: number of bits that determines the quantization range.
        signed: Whether the quantization is signed or not.
        threshold: threshold for quantization.
        eps: Small value for numerical stability in division.

    Returns:
        Uniform Quantized tensor.

    """

    # Integer clipping bounds: symmetric-ish range for signed, [0, 2^n - 1] for unsigned.
    if signed:
        clip_min, clip_max = -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1
    else:
        clip_min, clip_max = 0, 2 ** n_bits - 1

    # One bit is consumed by the sign when quantization is signed.
    range_scale = 2 ** (n_bits - int(signed))
    scaled = (data / (threshold + eps)) * range_scale
    return tf.clip_by_value(scaled, clip_value_max=clip_max, clip_value_min=clip_min)
@@ -0,0 +1,27 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import \
16
+ BaseKerasInferableQuantizer
17
+
18
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_pot_inferable_quantizer import WeightsPOTInferableQuantizer
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_symmetric_inferable_quantizer import WeightsSymmetricInferableQuantizer
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_uniform_inferable_quantizer import WeightsUniformInferableQuantizer
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_lut_symmetric_inferable_quantizer import WeightsLUTSymmetricInferableQuantizer
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_lut_pot_inferable_quantizer import WeightsLUTPOTInferableQuantizer
23
+
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_pot_inferable_quantizer import ActivationPOTInferableQuantizer
25
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_symmetric_inferable_quantizer import ActivationSymmetricInferableQuantizer
26
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_uniform_inferable_quantizer import ActivationUniformInferableQuantizer
27
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_lut_pot_inferable_quantizer import ActivationLutPOTInferableQuantizer
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================