mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -12,32 +12,206 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
15
+ import copy
16
16
  from typing import Callable
17
+ from functools import partial
18
+
19
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH, PYTORCH
17
20
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
21
+ from model_compression_toolkit import CoreConfig
22
+ from model_compression_toolkit.core import common
19
23
  from model_compression_toolkit.core.common import Logger
20
- from model_compression_toolkit.core.common.constants import PYTORCH
21
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
22
- from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
23
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
24
- from model_compression_toolkit import CoreConfig
25
+ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
26
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
27
+ MixedPrecisionQuantizationConfigV2
28
+ from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
29
+ from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
30
+ from model_compression_toolkit.ptq.runner import ptq_runner
31
+
25
32
 
26
33
  if FOUND_TORCH:
34
+ import torch.nn as nn
35
+ from torch.nn import Module
27
36
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
28
37
  from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
29
- from torch.nn import Module
30
-
38
+ from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
39
+ from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
40
+ from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
41
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
42
+ from model_compression_toolkit import quantizers_infrastructure as qi
31
43
  from model_compression_toolkit import get_target_platform_capabilities
44
+ from model_compression_toolkit.qat.common.qat_config import QATConfig
45
+ from model_compression_toolkit.qat.pytorch.quantizer.quantization_builder import quantization_builder
32
46
  DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
33
47
 
34
- def pytorch_quantization_aware_training_init(in_module: Module,
48
+
49
+ def qat_wrapper(n: common.BaseNode, module: nn.Module, qat_config: QATConfig):
50
+ """
51
+ A function which takes a computational graph node and a pytorch module and perform the quantization wrapping
52
+ Args:
53
+ n: A node of mct graph.
54
+ module: A Pytorch module
55
+ qat_config (QATConfig): QAT configuration
56
+ Returns: Wrapped layer
57
+
58
+ """
59
+ if _is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
60
+ weights_quantizers, activation_quantizers = quantization_builder(n, qat_config, DEFAULT_PYTORCH_INFO)
61
+ return qi.PytorchQuantizationWrapper(module, weights_quantizers, activation_quantizers)
62
+ else:
63
+ return module
64
+
65
+
66
+ def pytorch_quantization_aware_training_init(in_model: Module,
35
67
  representative_data_gen: Callable,
36
68
  target_kpi: KPI = None,
37
69
  core_config: CoreConfig = CoreConfig(),
70
+ qat_config: QATConfig = QATConfig(),
38
71
  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
39
72
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
40
- Logger.error("Quantization Aware Training isn't supported yet.")
73
+ """
74
+ Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
75
+ with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
76
+ quantized using a symmetric quantization thresholds (power of two).
77
+ The model is first optimized using several transformations (e.g. BatchNormalization folding to
78
+ preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
79
+ being collected for each layer's output (and input, depends on the quantization configuration).
80
+ For each possible bit width (per layer) a threshold is then being calculated using the collected
81
+ statistics. Then, if given a mixed precision config in the core_config, using an ILP solver we find
82
+ a mixed-precision configuration, and set a bit-width for each layer. The model is built with fake_quant
83
+ nodes for quantizing activation. Weights are kept as float and are quantized online while training by the
84
+ quantization wrapper's weight quantizer.
85
+ In order to limit the maximal model's size, a target KPI need to be passed after weights_memory
86
+ is set (in bytes).
87
+
88
+ Args:
89
+ in_model (Model): Pytorch model to quantize.
90
+ representative_data_gen (Callable): Dataset used for initial calibration.
91
+ target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
92
+ core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
93
+ qat_config (QATConfig): QAT configuration
94
+ fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Pytorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
95
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to.
96
+
97
+ Returns:
98
+
99
+ A quantized model.
100
+ User information that may be needed to handle the quantized model.
101
+
102
+ Examples:
103
+
104
+ Import MCT:
105
+
106
+ >>> import model_compression_toolkit as mct
107
+
108
+ Import a Pytorch model:
109
+
110
+ >>> from torchvision.models import mobilenet_v2
111
+ >>> model = mobilenet_v2(pretrained=True)
112
+
113
+ Create a random dataset generator, for required number of calibration iterations (num_calibration_batches):
114
+ In this example a random dataset of 10 batches each containing 4 images is used.
115
+
116
+ >>> import numpy as np
117
+ >>> num_calibration_batches = 10
118
+ >>> def repr_datagen():
119
+ >>> for _ in range(num_calibration_batches):
120
+ >>> yield [np.random.random((4, 3, 224, 224))]
121
+
122
+ Create a MCT core config, containing the quantization configuration:
123
+
124
+ >>> config = mct.CoreConfig()
125
+
126
+ Pass the model, the representative dataset generator, the configuration and the target KPI to get a
127
+ quantized model. Now the model contains quantizer wrappers for fine tunning the weights:
128
+
129
+ >>> quantized_model, quantization_info = pytorch_quantization_aware_training_init(model, repr_datagen, core_config=config)
130
+
131
+ For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
132
+
133
+ """
134
+
135
+ if core_config.mixed_precision_enable:
136
+ if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
137
+ common.Logger.error("Given quantization config to mixed-precision facade is not of type "
138
+ "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
139
+ "or pass a valid mixed precision configuration.")
140
+
141
+ common.Logger.info("Using experimental mixed-precision quantization. "
142
+ "If you encounter an issue please file a bug.")
143
+
144
+ tb_w = _init_tensorboard_writer(fw_info)
145
+
146
+ fw_impl = PytorchImplementation()
147
+
148
+ tg, bit_widths_config = core_runner(in_model=in_model,
149
+ representative_data_gen=representative_data_gen,
150
+ core_config=core_config,
151
+ fw_info=DEFAULT_PYTORCH_INFO,
152
+ fw_impl=fw_impl,
153
+ tpc=target_platform_capabilities,
154
+ target_kpi=target_kpi,
155
+ tb_w=tb_w)
156
+
157
+ tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
158
+
159
+ _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
160
+
161
+ qat_model, user_info = PyTorchModelBuilder(graph=tg, fw_info=fw_info, wrapper=_qat_wrapper).build_model()
162
+
163
+ user_info.mixed_precision_cfg = bit_widths_config
164
+
165
+ return qat_model, user_info
166
+
167
def pytorch_quantization_aware_training_finalize(in_model: Module):
    """
    Convert a user fine-tuned model into a network whose PytorchQuantizationWrapper
    layers hold InferableQuantizers, so that both weights and outputs are quantized
    for inference.

    Args:
        in_model (Model): Pytorch model whose trainable quantizers should be converted.

    Returns:
        A quantized model with QuantizeWrappers and InferableQuantizers.

    Examples:

        Import MCT:

        >>> import model_compression_toolkit as mct

        Import a Pytorch model:

        >>> from torchvision.models import mobilenet_v2
        >>> model = mobilenet_v2(pretrained=True)

        Create a random dataset generator:

        >>> import numpy as np
        >>> def repr_datagen(): yield [np.random.random((1, 224, 224, 3))]

        Create a MCT core config, containing the quantization configuration:

        >>> config = mct.CoreConfig()

        Pass the model, the representative dataset generator, the configuration and the target KPI to get a
        quantized model:

        >>> quantized_model, quantization_info = pytorch_quantization_aware_training_init(model, repr_datagen, core_config=config)

        Use the quantized model for fine-tuning. Finally, remove the quantizer wrappers and keep a quantize model ready for inference.

        >>> quantized_model = mct.pytorch_quantization_aware_training_finalize(quantized_model)

    """
    # Work on a deep copy so the user's fine-tuned model is left untouched.
    model_copy = copy.deepcopy(in_model)

    # NOTE(review): only direct children are scanned — assumes quantization
    # wrappers are top-level modules of the exported graph model; confirm for
    # models with nested wrapped submodules.
    for submodule in model_copy.children():
        if isinstance(submodule, PytorchQuantizationWrapper):
            submodule.convert_to_inferable_quantizers()

    return model_copy
214
+
41
215
 
42
216
  else:
43
217
  # If torch is not installed,
@@ -45,4 +219,9 @@ else:
45
219
    def pytorch_quantization_aware_training_init(*args, **kwargs):
        """
        Stub installed in place of the real API when the torch package is missing.
        """
        # Fail fast with an installation hint instead of a NameError deeper in the stack.
        Logger.critical('Installing Pytorch is mandatory '
                        'when using pytorch_quantization_aware_training_init. '
                        'Could not find the torch package.')  # pragma: no cover
223
+
224
    def pytorch_quantization_aware_training_finalize(*args, **kwargs):
        """
        Stub installed in place of the real API when the torch package is missing.
        """
        # Fail fast with an installation hint instead of a NameError deeper in the stack.
        Logger.critical('Installing Pytorch is mandatory '
                        'when using pytorch_quantization_aware_training_finalize. '
                        'Could not find the torch package.')  # pragma: no cover
@@ -0,0 +1,17 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import model_compression_toolkit.qat.pytorch.quantizer.ste_rounding.symmetric_ste
17
+ import model_compression_toolkit.qat.pytorch.quantizer.ste_rounding.uniform_ste
@@ -0,0 +1,49 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Union
16
+
17
+ from model_compression_toolkit.core.common.logger import Logger
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+
20
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
21
+ TrainableQuantizerActivationConfig
22
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
23
+ BasePytorchTrainableQuantizer
24
+
25
+ if FOUND_TORCH:
26
+
27
+ class BasePytorchQATTrainableQuantizer(BasePytorchTrainableQuantizer):
28
+ """
29
+ A base class for trainable Keras quantizer for QAT.
30
+ """
31
+
32
+ def __init__(self,
33
+ quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
34
+ """
35
+ Initializes BasePytorchQATTrainableQuantizer object.
36
+
37
+ Args:
38
+ quantization_config: quantizer config class contains all the information about a quantizer configuration.
39
+ """
40
+ super().__init__(quantization_config)
41
+
42
+ else:
43
    class BasePytorchQATTrainableQuantizer(BasePytorchTrainableQuantizer):
        """
        Placeholder used when torch is not installed: constructing the quantizer
        immediately reports the missing dependency.
        """
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            super().__init__(quantization_config)
            # Fail fast with an installation hint instead of a NameError later.
            Logger.critical('Installing Pytorch is mandatory '
                            'when using BasePytorchQATTrainableQuantizer. '
                            'Could not find torch package.')  # pragma: no cover
@@ -0,0 +1,74 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List, Dict, Tuple
16
+
17
+ from model_compression_toolkit.core import common
18
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
19
+ from model_compression_toolkit.qat.common.qat_config import QATConfig
20
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
21
+ get_trainable_quantizer_quantization_candidates, get_trainable_quantizer_weights_config, \
22
+ get_trainable_quantizer_activation_config
23
+ from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
24
+ from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
25
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
26
+ get_trainable_quantizer_class
27
+
28
+
29
def quantization_builder(n: common.BaseNode,
                         qat_config: QATConfig,
                         fw_info: FrameworkInfo,
                         ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer],
                                    List[BasePytorchQATTrainableQuantizer]]:
    """
    Create the trainable QAT quantizers a node requires, according to its
    quantization configuration.

    Args:
        n: Node to build its QuantizeConfig.
        qat_config (QATConfig): QAT configuration
        fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).

    Returns:
        A dictionary mapping each weight attribute name to its quantizer, and a
        list of activation quantizers (one per layer output).
    """
    # Mixed-precision nodes carry several quantization candidates; nodes with a
    # single candidate pass None so the final config is used directly.
    if len(n.candidates_quantization_cfg) > 1:
        weights_candidates, activation_candidates = get_trainable_quantizer_quantization_candidates(n)
    else:
        weights_candidates, activation_candidates = None, None

    weight_quantizers = {}
    if n.is_weights_quantization_enabled():
        weights_quantizer_class = get_trainable_quantizer_class(
            QuantizationTarget.Weights,
            qat_config.weight_training_method,
            n.final_weights_quantization_cfg.weights_quantization_method,
            BasePytorchQATTrainableQuantizer)
        # One quantizer per kernel attribute of this layer type.
        for attr in fw_info.get_kernel_op_attributes(n.type):
            weight_quantizers[attr] = weights_quantizer_class(
                get_trainable_quantizer_weights_config(n, weights_candidates),
                **qat_config.weight_quantizer_params_override)

    activation_quantizers = []
    if n.is_activation_quantization_enabled():
        activation_quantizer_class = get_trainable_quantizer_class(
            QuantizationTarget.Activation,
            qat_config.activation_training_method,
            n.final_activation_quantization_cfg.activation_quantization_method,
            BasePytorchQATTrainableQuantizer)
        activation_quantizers.append(activation_quantizer_class(
            get_trainable_quantizer_activation_config(n, activation_candidates),
            **qat_config.activation_quantizer_params_override))

    return weight_quantizers, activation_quantizers
@@ -0,0 +1,136 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Tuple
16
+ import torch
17
+
18
+
19
def ste_round(x: torch.Tensor) -> torch.Tensor:
    """
    Round a tensor using a straight-through estimator (STE).

    The forward pass yields torch.round(x); because the rounding residual is
    detached, the backward pass propagates gradients as if this were identity.

    Args:
        x: input variable
    Returns:
        rounded value
    """
    rounding_residual = (torch.round(x) - x).detach()
    return x + rounding_residual
28
+
29
+
30
def ste_clip(x: torch.Tensor, min_val=-1.0, max_val=1.0) -> torch.Tensor:
    """
    Clip a tensor to [min_val, max_val] using a straight-through estimator (STE).

    The forward pass yields torch.clip(x, min_val, max_val); because the
    clipping residual is detached, gradients flow through as identity.

    Args:
        x: input variable
        min_val: minimum value for clipping
        max_val: maximum value for clipping
    Returns:
        clipped variable
    """
    clipping_residual = (torch.clip(x, min=min_val, max=max_val) - x).detach()
    return x + clipping_residual
41
+
42
+
43
def fix_range_to_include_zero(range_min: torch.Tensor,
                              range_max: torch.Tensor,
                              n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Adjust a quantization range so that 0.0 falls exactly on the quantization grid.

    For per-channel quantization, range_min and range_max should be shaped so
    broadcasting applies along the channel axis.

    Args:
        range_min: min bound of the quantization range (before adjustment).
        range_max: max bound of the quantization range (before adjustment).
        n_bits: Number of bits to quantize the tensor.
    Returns: adjusted quantization range
    """
    # Float masks for the three mutually-exclusive cases per element:
    # range entirely positive, entirely negative, or straddling zero.
    all_positive = (range_min > 0).float()
    all_negative = (range_max < 0).float()
    straddles_zero = torch.logical_and(range_min <= 0, range_max >= 0).float()

    # Snap range_min onto a multiple of the step size, then shift range_max by
    # the same amount so the range width is preserved.
    step = (range_max - range_min) / (2 ** n_bits - 1)
    snapped_min = step * torch.round(range_min / step)
    snapped_max = range_max - range_min + snapped_min

    # All-positive ranges are clamped to start at 0 (min term vanishes);
    # all-negative ranges end at 0 (max term vanishes); straddling ranges
    # use the snapped bounds.
    adjusted_min = snapped_min * straddles_zero + all_negative * range_min
    adjusted_max = snapped_max * straddles_zero + all_positive * range_max
    return adjusted_min, adjusted_max
70
+
71
+
72
def symmetric_quantizer(tensor_data: torch.Tensor,
                        threshold: torch.Tensor,
                        n_bits: int,
                        sign: bool = False) -> torch.Tensor:
    """
    Symmetric quantization of a tensor given a threshold and number of bits.

    Args:
        tensor_data: Tensor values to quantize.
        threshold: threshold for quantization.
        n_bits: Number of bits to quantize the tensor.
        sign: sign of tensor_data
    Returns:
        Quantized data.
    """
    # Number of quantization levels on the non-negative side of the grid;
    # signed quantization spends one bit on the sign.
    positive_levels = 2 ** (n_bits - int(sign))
    step = threshold / positive_levels

    # Integer grid bounds: [-positive_levels, positive_levels - 1] when signed,
    # [0, positive_levels - 1] otherwise.
    lowest_level = -int(sign) * positive_levels
    highest_level = positive_levels - 1

    # STE round to the integer grid, then STE clip into the valid level range.
    level_indices = ste_round(tensor_data / step)
    clipped_levels = ste_clip(level_indices, min_val=lowest_level, max_val=highest_level)

    # Map integer levels back to the [-threshold, threshold) value range.
    return step * clipped_levels
105
+
106
+
107
def uniform_quantizer(tensor_data: torch.Tensor,
                      range_min: torch.Tensor,
                      range_max: torch.Tensor,
                      n_bits: int) -> torch.Tensor:
    """
    Uniform quantization of a tensor given a (min, max) range and number of bits.

    Args:
        tensor_data: Tensor values to quantize.
        range_min: minimum bound of the range for quantization (or array of min values per channel).
        range_max: maximum bound of the range for quantization (or array of max values per channel).
        n_bits: Number of bits to quantize the tensor.
    Returns:
        Quantized data.
    """
    # Snap the range so that 0.0 is exactly representable on the grid.
    low, high = fix_range_to_include_zero(range_min, range_max, n_bits)

    # Step size between adjacent quantization levels.
    num_steps = 2 ** n_bits - 1
    step = (high - low) / num_steps

    # STE round onto the integer grid, then STE clip into [0, num_steps].
    level_indices = ste_round((tensor_data - low) / step)
    clipped_levels = ste_clip(level_indices, min_val=0, max_val=num_steps)

    # Map integer levels back into the adjusted [low, high] value range.
    return step * clipped_levels + low