mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,179 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import warnings
16
+ from typing import List
17
+
18
+ import numpy as np
19
+
20
+ from model_compression_toolkit.core.common.constants import FOUND_TF
21
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
23
+ QuantizationTarget
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import MULTIPLIER_N_BITS, EPS
25
+
26
if FOUND_TF:
    import tensorflow as tf
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import \
        BaseKerasInferableQuantizer
    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizer_utils import \
        lut_quantizer

    @mark_quantizer(quantization_target=QuantizationTarget.Weights,
                    quantization_method=[QuantizationMethod.LUT_SYM_QUANTIZER],
                    quantizer_type=None)
    class WeightsLUTSymmetricInferableQuantizer(BaseKerasInferableQuantizer):
        """
        Quantizer for weights using a symmetric look-up-table (LUT) quantizer.

        Weights are mapped to a small set of integer cluster centers (the LUT),
        scaled by a per-tensor or per-channel threshold. Quantization itself is
        delegated to `lut_quantizer`; this class validates the configuration and
        handles moving the channel axis to the last position when needed.
        """

        def __init__(self,
                     num_bits: int,
                     cluster_centers: np.ndarray,
                     threshold: List[float],
                     per_channel: bool,
                     channel_axis: int = None,
                     input_rank: int = None,
                     multiplier_n_bits: int = MULTIPLIER_N_BITS,
                     eps: float = EPS):
            """
            Initialize the quantizer with the specified parameters.

            Args:
                num_bits: number of bits to use for quantization
                cluster_centers: the cluster centers to assign the weights
                threshold: threshold for quantizing weights (one value per channel
                    when per_channel is True, otherwise a single-element list)
                per_channel: whether to use per-channel quantization
                channel_axis: axis along which to apply per-channel quantization
                input_rank: number of dimensions of input tensor the quantizer quantizes
                multiplier_n_bits: number of bits that determines the quantization range
                eps: small value for numerical stability in division
            """
            super(WeightsLUTSymmetricInferableQuantizer, self).__init__()

            assert isinstance(threshold, list), f'Expected threshold to be of type list but is {type(threshold)}'
            assert all([isinstance(x, (float, np.float32, np.float64)) for x in
                        threshold]), f'Expected threshold list to contain float or np.float values but found ' \
                                     f'{[type(x) for x in threshold]}'

            self.threshold = np.asarray(threshold)

            if per_channel:
                assert input_rank is not None, f'Input rank is missing in per channel quantization'
                assert channel_axis is not None, f'Channel axis is missing in per channel quantization'
                assert len(threshold) >= 1, f'In per-channel quantization threshold list should be of length >= 1 ' \
                                            f'but is {len(threshold)} '
            else:
                assert len(threshold) == 1, f'In per-tensor quantization threshold should be of length 1 but is' \
                                            f' {len(threshold)}'
                # Per-tensor quantization uses a scalar threshold.
                self.threshold = self.threshold[0]

            # The LUT can represent at most 2**num_bits distinct values.
            # Bug fix: report the number of UNIQUE cluster centers, which is what
            # is actually checked (previously the raw length was reported).
            assert len(np.unique(cluster_centers)) <= 2 ** num_bits, \
                f'Expected num of cluster centers to be less or equal than {2 ** num_bits} ' \
                f'but got {len(np.unique(cluster_centers))}'

            assert not np.any(cluster_centers - cluster_centers.astype(int)), f'Expected cluster centers to be integers'

            # Weight quantization is signed, hence the quantization range is
            # [-2**(multiplier_n_bits - 1), 2**(multiplier_n_bits - 1) - 1]
            assert np.all((-1 * (2 ** (multiplier_n_bits - 1)) <= cluster_centers) &
                          (cluster_centers <= (2 ** (multiplier_n_bits - 1) - 1))), \
                f'Expected cluster centers in the quantization range'

            # num_bits must not exceed multiplier_n_bits (equality is allowed but
            # inefficient — see warning below).
            # Bug fix: the message now matches the <= check (was "must be less then").
            assert num_bits <= multiplier_n_bits, f'Look-Up-Table bit configuration has {num_bits} bits. It must be ' \
                                                  f'less than or equal to {multiplier_n_bits}'
            if num_bits == multiplier_n_bits:
                warnings.warn("Num of bits equal to multiplier n bits, Please be aware LUT quantizer may be "
                              "inefficient in that case, consider using SymmetricInferableQuantizer instead")

            self.num_bits = num_bits
            self.cluster_centers = cluster_centers
            self.multiplier_n_bits = multiplier_n_bits
            self.eps = eps
            self.per_channel = per_channel
            self.channel_axis = channel_axis
            self.input_rank = input_rank

            # Tensorflow's fake_quant_with_min_max_vars_per_channel only works on last axis, so
            # need to move the quantization axis to the last axis
            if per_channel and channel_axis not in [-1, self.input_rank - 1]:
                # If per-channel quantization is being used and the channel axis is not the last axis,
                # create a permutation vector to move the channel axis to the last position
                self.perm_vec = list(np.arange(self.input_rank))
                self.perm_vec[channel_axis] = self.input_rank - 1
                self.perm_vec[self.input_rank - 1] = channel_axis
            else:
                # If per-channel quantization is not being used or the channel axis is already the last axis,
                # set the permutation vector to None
                self.perm_vec = None

        def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
            """
            Quantize the given inputs using the quantizer parameters.

            Args:
                inputs: input tensor to quantize

            Returns:
                quantized tensor.
            """
            assert inputs.dtype == tf.float32, f'Input tensor was expected to be a float tensor but is of type ' \
                                               f'{inputs.dtype}'

            # If per-channel quantization is being used
            if self.per_channel:
                # If a permutation vector has been created to move the channel axis to the last position
                if self.perm_vec:
                    # Transpose the input tensor to move the channel axis to the last position
                    inputs = tf.transpose(inputs, perm=self.perm_vec)

                # Quantize the input tensor using per-channel quantization
                q_tensor = lut_quantizer(inputs, cluster_centers=self.cluster_centers, signed=True,
                                         threshold=self.threshold, multiplier_n_bits=self.multiplier_n_bits,
                                         eps=self.eps)
                if self.perm_vec:
                    # Transpose the quantized tensor back to its original shape
                    q_tensor = tf.transpose(q_tensor, perm=self.perm_vec)

                # Return the quantized tensor
                return q_tensor
            else:
                return lut_quantizer(inputs, cluster_centers=self.cluster_centers, signed=True,
                                     threshold=self.threshold, multiplier_n_bits=self.multiplier_n_bits, eps=self.eps)

        def get_config(self):
            """
            Return a dictionary with the configuration of the quantizer.

            Returns:
                Dictionary with the following keys: 'per_channel', 'num_bits', 'cluster_centers', 'threshold',
                'channel_axis', 'input_rank', 'multiplier_n_bits', 'eps'
            """
            return {'per_channel': self.per_channel,
                    'num_bits': self.num_bits,
                    'cluster_centers': self.cluster_centers,
                    'threshold': self.threshold,
                    'channel_axis': self.channel_axis,
                    'input_rank': self.input_rank,
                    'multiplier_n_bits': self.multiplier_n_bits,
                    'eps': self.eps}

else:
    class WeightsLUTSymmetricInferableQuantizer:  # pragma: no cover
        def __init__(self, *args, **kwargs):
            raise Exception('Installing tensorflow and tensorflow_model_optimization is mandatory '
                            'when using WeightsLUTSymmetricInferableQuantizer. '
                            'Could not find Tensorflow package.')
@@ -0,0 +1,67 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List
16
+
17
+ import numpy as np
18
+
19
+ from model_compression_toolkit.core.common.constants import FOUND_TF
20
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
22
+
23
+ if FOUND_TF:
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_symmetric_inferable_quantizer \
25
+ import WeightsSymmetricInferableQuantizer
26
+
27
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.POWER_OF_TWO],
                quantizer_type=None)
class WeightsPOTInferableQuantizer(WeightsSymmetricInferableQuantizer):
    """
    Class for quantizing weights using a power-of-two quantizer.
    """

    def __init__(self,
                 num_bits: int,
                 threshold: List[float],
                 per_channel: bool,
                 channel_axis: int = None,
                 input_rank: int = None):
        """
        Initialize the quantizer with the specified parameters.

        Args:
            num_bits: number of bits to use for quantization
            threshold: threshold for quantizing activations
            per_channel: whether to use per-channel quantization
            channel_axis: axis along which to apply per-channel quantization
            input_rank: number of dimensions of input tensor the quantizer quantizes
        """
        # The symmetric base class performs all validation and min/max range computation.
        super(WeightsPOTInferableQuantizer, self).__init__(num_bits=num_bits,
                                                           threshold=threshold,
                                                           per_channel=per_channel,
                                                           channel_axis=channel_axis,
                                                           input_rank=input_rank)

        # Each threshold must be an exact power of two, i.e. its log2 is an integer.
        # astype(int) truncates toward zero exactly like the per-element int() it replaces.
        log2_thresholds = np.log2(self.threshold.flatten())
        is_threshold_pot = np.all(log2_thresholds == log2_thresholds.astype(int))
        assert is_threshold_pot, f'Expected threshold to be power of 2 but is {self.threshold}'
60
+
61
+
62
+ else:
63
class WeightsPOTInferableQuantizer:  # pragma: no cover
    """Stand-in used when TensorFlow is absent; any instantiation raises immediately."""

    def __init__(self, *args, **kwargs):
        # TensorFlow is a hard requirement for this quantizer; fail loudly and early.
        raise Exception('Installing tensorflow and tensorflow_model_optimization is mandatory '
                        'when using WeightsPOTInferableQuantizer. '
                        'Could not find Tensorflow package.')
@@ -0,0 +1,87 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List
16
+
17
+ import numpy as np
18
+
19
+ from model_compression_toolkit.core.common.constants import FOUND_TF
20
+
21
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
23
+
24
+ if FOUND_TF:
25
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_uniform_inferable_quantizer \
26
+ import WeightsUniformInferableQuantizer
27
+
28
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.SYMMETRIC],
                quantizer_type=None)
class WeightsSymmetricInferableQuantizer(WeightsUniformInferableQuantizer):
    """
    Class for quantizing weights using a symmetric (zero-centered) quantizer.
    """

    def __init__(self,
                 num_bits: int,
                 threshold: List[float],
                 per_channel: bool,
                 channel_axis: int = None,
                 input_rank: int = None):
        """
        Initialize the quantizer with the specified parameters.

        Args:
            num_bits: number of bits to use for quantization
            threshold: threshold for quantizing weights
            per_channel: whether to use per-channel quantization
            channel_axis: axis along which to apply per-channel quantization
            input_rank: number of dimensions of input tensor the quantizer quantizes
        """
        assert isinstance(threshold, list), f'Expected threshold to be of type list but is {type(threshold)}'
        all_floats = all(isinstance(x, (float, np.float32, np.float64)) for x in threshold)
        assert all_floats, f'Expected threshold list to contain float or np.float values but found ' \
                           f'{[type(x) for x in threshold]}'

        self.threshold = np.asarray(threshold)

        # Symmetric grid around zero: [-t, t - lsb] with lsb = t / 2**(num_bits - 1),
        # so the negative side uses the full range and zero is exactly representable.
        lsb = self.threshold / (2 ** (num_bits - 1))
        lower_bounds = -self.threshold
        upper_bounds = self.threshold - lsb

        super(WeightsSymmetricInferableQuantizer, self).__init__(num_bits=num_bits,
                                                                 min_range=list(lower_bounds),
                                                                 max_range=list(upper_bounds),
                                                                 per_channel=per_channel,
                                                                 channel_axis=channel_axis,
                                                                 input_rank=input_rank)

    def get_config(self):
        """
        Return the configuration of this quantizer.

        Returns:
            Dictionary with the following keys: 'num_bits', 'threshold', 'per_channel',
            'channel_axis', 'input_rank'
        """
        return dict(num_bits=self.num_bits,
                    threshold=self.threshold,
                    per_channel=self.per_channel,
                    channel_axis=self.channel_axis,
                    input_rank=self.input_rank)
81
+
82
+ else:
83
class WeightsSymmetricInferableQuantizer:  # pragma: no cover
    """Stand-in used when TensorFlow is absent; any instantiation raises immediately."""

    def __init__(self, *args, **kwargs):
        # Fix: the message previously named WeightsPOTInferableQuantizer (copy-paste
        # from the POT module); it must name this class so users see which quantizer failed.
        raise Exception('Installing tensorflow and tensorflow_model_optimization is mandatory '
                        'when using WeightsSymmetricInferableQuantizer. '
                        'Could not find Tensorflow package.')
@@ -0,0 +1,163 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List
16
+
17
+ import numpy as np
18
+
19
+ from model_compression_toolkit.core.common.constants import FOUND_TF
20
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.quant_utils import \
23
+ adjust_range_to_include_zero
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.validation_functions import \
25
+ validate_uniform_min_max_ranges, validate_adjusted_min_max_ranges
26
+
27
+ if FOUND_TF:
28
+ import tensorflow as tf
29
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
30
+
31
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=None)
class WeightsUniformInferableQuantizer(BaseKerasInferableQuantizer):
    """
    Class for quantizing weights using a uniform quantizer
    """

    def __init__(self,
                 num_bits: int,
                 min_range: List[float],
                 max_range: List[float],
                 per_channel: bool,
                 channel_axis: int = None,
                 input_rank: int = None
                 ):
        """
        Initialize the quantizer with the specified parameters.

        Args:
            num_bits: number of bits to use for quantization
            min_range: min quantization range for quantizing weights
            max_range: max quantization range for quantizing weights
            per_channel: whether to use per-channel quantization
            channel_axis: axis along which to apply per-channel quantization
            input_rank: number of dimensions of input tensor the quantizer quantizes
        """
        super(WeightsUniformInferableQuantizer, self).__init__()

        # Validate inputs properties
        validate_uniform_min_max_ranges(min_range, max_range)

        # Convert min/max to numpy arrays and adjust the range so zero is exactly representable
        min_range, max_range = np.asarray(min_range), np.asarray(max_range)
        _min_range, _max_range = adjust_range_to_include_zero(min_range, max_range, num_bits)
        validate_adjusted_min_max_ranges(min_range=min_range,
                                         max_range=max_range,
                                         adj_min=_min_range,
                                         adj_max=_max_range)

        self.num_bits = num_bits
        self.max_range = _max_range
        self.min_range = _min_range

        if per_channel:
            assert input_rank is not None, f'Input rank is missing in per channel quantization'
            assert channel_axis is not None, f'Channel axis is missing in per channel quantization'
            assert len(self.min_range) >= 1, f'In per-channel quantization min ranges list should be of length >= 1 but is {len(self.min_range)}'
            assert len(self.max_range) >= 1, f'In per-channel quantization max ranges list should be of length >= 1 but is {len(self.max_range)}'
        else:
            # Fix: a redundant duplicate assert on len(self.min_range) was removed here —
            # one length check per range is sufficient.
            assert len(self.min_range) == 1, f'In per-tensor quantization min_range should be of length 1 but is {len(self.min_range)}'
            assert len(self.max_range) == 1, f'In per-tensor quantization max_range should be of length 1 but is {len(self.max_range)}'
            # Per-tensor quantization uses scalar min/max values
            self.min_range = self.min_range[0]
            self.max_range = self.max_range[0]

        self.per_channel = per_channel
        self.channel_axis = channel_axis
        self.input_rank = input_rank

        # Tensorflow's fake_quant_with_min_max_vars_per_channel only works on the last axis,
        # so the quantization axis may need to be moved to the last position.
        if per_channel and channel_axis not in [-1, self.input_rank - 1]:
            # Swap channel_axis with the last axis; applying the same permutation twice
            # restores the original order, so one vector serves both transposes.
            self.perm_vec = list(np.arange(self.input_rank))
            self.perm_vec[channel_axis] = self.input_rank - 1
            self.perm_vec[self.input_rank - 1] = channel_axis
        else:
            # No transpose needed: per-tensor quantization or channel axis already last.
            self.perm_vec = None

    def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Quantize the given inputs using the quantizer parameters.

        Args:
            inputs: input tensor to quantize

        Returns:
            quantized tensor.
        """
        assert inputs.dtype==tf.float32, f'Input tensor was expected to be a float tensor but is of type {inputs.dtype}'

        # If per-channel quantization is being used
        if self.per_channel:
            if self.perm_vec:
                # Move the channel axis to the last position as required by TF's per-channel op
                inputs = tf.transpose(inputs, perm=self.perm_vec)

            # Quantize the input tensor using per-channel quantization
            q_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(inputs,
                                                                                min=self.min_range,
                                                                                max=self.max_range,
                                                                                num_bits=self.num_bits)
            if self.perm_vec:
                # Transpose the quantized tensor back to its original shape
                q_tensor = tf.transpose(q_tensor, perm=self.perm_vec)

            # Return the quantized tensor
            return q_tensor
        else:
            # Per-tensor quantization with scalar min/max values
            return tf.quantization.fake_quant_with_min_max_vars(inputs,
                                                                min=self.min_range,
                                                                max=self.max_range,
                                                                num_bits=self.num_bits)

    def get_config(self):
        """
        Return a dictionary with the configuration of the quantizer.

        Returns:
            Dictionary with the following keys: 'num_bits', 'min_range', 'max_range', 'per_channel',
            'channel_axis', 'input_rank'
        """
        return {'per_channel': self.per_channel,
                'num_bits': self.num_bits,
                'max_range': self.max_range,
                'min_range': self.min_range,
                'channel_axis': self.channel_axis,
                'input_rank': self.input_rank}
156
+
157
+
158
+ else:
159
class WeightsUniformInferableQuantizer:  # pragma: no cover
    """Stand-in used when TensorFlow is absent; any instantiation raises immediately."""

    def __init__(self, *args, **kwargs):
        # TensorFlow is a hard requirement for this quantizer; fail loudly and early.
        raise Exception('Installing tensorflow and tensorflow_model_optimization is mandatory '
                        'when using WeightsUniformInferableQuantizer. '
                        'Could not find Tensorflow package.')
@@ -0,0 +1,66 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Any
16
+
17
+ import numpy as np
18
+
19
+ from model_compression_toolkit.core.common import Logger
20
+
21
+
22
def validate_uniform_min_max_ranges(min_range: Any, max_range: Any) -> None:
    """
    Validate that min/max ranges for a uniform quantizer are well formed.

    Args:
        min_range: min range list to check
        max_range: max range list to check

    """
    assert isinstance(min_range, list), f'Expected min_range to be of type list but is {type(min_range)}'
    assert isinstance(max_range, list), f'Expected max_range to be of type list but is {type(max_range)}'

    float_types = (float, np.float32, np.float64)
    assert all(isinstance(x, float_types) for x in min_range), \
        f'Expected min_range list to contain float values but found {[type(x) for x in min_range]}'
    assert all(isinstance(x, float_types) for x in max_range), \
        f'Expected max_range list to contain float values but found {[type(x) for x in max_range]}'

    assert len(min_range) == len(max_range), \
        f'Expected min/max values to have the same length but min shape: {len(min_range)} and max shape: ' \
        f'{len(max_range)}'

    # Element-wise check that each max strictly exceeds its paired min.
    assert np.all(np.asarray(max_range) > np.asarray(min_range)), f'Expected max_range to be bigger than min_range!'
46
+
47
+
48
def validate_adjusted_min_max_ranges(min_range: Any,
                                     max_range: Any,
                                     adj_min: Any,
                                     adj_max: Any) -> None:
    """
    Validate adjusted min/max ranges in uniform quantization are valid.

    Args:
        min_range: original min range
        max_range: original max range
        adj_min: adjusted min range
        adj_max: adjusted max range

    """
    # Zero must lie inside the adjusted range so it can be represented exactly.
    zero_in_range = np.all(adj_min <= 0) and np.all(adj_max >= 0)
    assert zero_in_range, f'Expected zero to be in the range, got min_range={adj_min}, max_range={adj_max}'

    # Warn when the adjustment actually moved either endpoint (beyond numerical noise).
    min_moved = not np.isclose(np.linalg.norm(adj_min - min_range), 0, atol=1e-6)
    max_moved = not np.isclose(np.linalg.norm(adj_max - max_range), 0, atol=1e-6)
    if min_moved or max_moved:
        Logger.warning(f"Adjusting (min_range, max_range) from ({min_range},{max_range}) to ({adj_min},{adj_max})")  # pragma: no cover
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================