mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -33,7 +33,7 @@ if FOUND_TORCH:
33
33
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
34
34
  from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
35
35
  from torch.nn import Module
36
- from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_fully_quantized_pytorch_model
36
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
37
37
  from model_compression_toolkit import get_target_platform_capabilities
38
38
 
39
39
  DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -62,7 +62,7 @@ if FOUND_TORCH:
62
62
  representative_data_gen (Callable): Dataset used for calibration.
63
63
  target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
64
64
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
65
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. `Default PyTorch TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/pytorch_tp_models/pytorch_default.py>`_
65
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
66
66
  new_experimental_exporter (bool): Whether exporting the quantized model using new exporter or not (in progress. Avoiding it for now is recommended).
67
67
 
68
68
  Returns:
@@ -95,8 +95,9 @@ if FOUND_TORCH:
95
95
  if core_config.mixed_precision_enable:
96
96
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
97
97
  common.Logger.error("Given quantization config to mixed-precision facade is not of type "
98
- "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
99
- "or pass a valid mixed precision configuration.")
98
+ "MixedPrecisionQuantizationConfigV2. Please use "
99
+ "pytorch_post_training_quantization API, or pass a valid mixed precision "
100
+ "configuration.") # pragma: no cover
100
101
 
101
102
  common.Logger.info("Using experimental mixed-precision quantization. "
102
103
  "If you encounter an issue please file a bug.")
@@ -127,7 +128,7 @@ if FOUND_TORCH:
127
128
  Logger.warning('Using new experimental exported models. '
128
129
  'Please do not use unless you are familiar with what you are doing')
129
130
 
130
- return get_fully_quantized_pytorch_model(tg)
131
+ return get_exportable_pytorch_model(tg)
131
132
 
132
133
  quantized_model, user_info = export_model(tg,
133
134
  DEFAULT_PYTORCH_INFO,
@@ -143,4 +144,4 @@ else:
143
144
  def pytorch_post_training_quantization_experimental(*args, **kwargs):
144
145
  Logger.critical('Installing Pytorch is mandatory '
145
146
  'when using pytorch_post_training_quantization_experimental. '
146
- 'Could not find the torch package.')
147
+ 'Could not find the torch package.') # pragma: no cover
@@ -0,0 +1,68 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Dict
17
+ from enum import Enum
18
+ from model_compression_toolkit.core import common
19
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
20
+
21
+ def _is_qat_applicable(node: common.BaseNode,
22
+ fw_info: FrameworkInfo) -> bool:
23
+ """
24
+ A function for deciding if a layer should be fine-tuned during QAT
25
+ Args:
26
+ node (BaseNode): Node for quantization decision
27
+ fw_info (FrameworkInfo): Pytorch quantization information
28
+
29
+ Returns:
30
+ A boolean whether the layer is to be wrapped with a QuantizeWrapper
31
+ """
32
+
33
+ if node.is_weights_quantization_enabled() and not fw_info.is_kernel_op(node.type):
34
+ common.Logger.error("QAT Error: Quantizing a node without a kernel isn't supported")
35
+ return node.is_weights_quantization_enabled() or node.is_activation_quantization_enabled()
36
+
37
+
38
+ class TrainingMethod(Enum):
39
+ """
40
+ An enum for selecting a QAT training method
41
+
42
+ STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers
43
+ """
44
+ STE = "STE",
45
+
46
+
47
+ class QATConfig:
48
+ """
49
+ QAT configuration class.
50
+ """
51
+
52
+ def __init__(self, weight_training_method: TrainingMethod = TrainingMethod.STE,
53
+ activation_training_method: TrainingMethod = TrainingMethod.STE,
54
+ weight_quantizer_params_override: Dict = None,
55
+ activation_quantizer_params_override: Dict = None,
56
+ ):
57
+ """
58
+
59
+ Args:
60
+ weight_training_method (TrainingMethod): Training method for weight quantizers
61
+ activation_training_method (TrainingMethod): Training method for activation quantizers:
62
+ weight_quantizer_params_override: A dictionary of parameters to override in weight quantization quantizer instantiation. Defaults to None (no parameters)
63
+ activation_quantizer_params_override: A dictionary of parameters to override in activation quantization quantizer instantiation. Defaults to None (no parameters)
64
+ """
65
+ self.weight_training_method = weight_training_method
66
+ self.activation_training_method = activation_training_method
67
+ self.weight_quantizer_params_override = {} if weight_quantizer_params_override is None else weight_quantizer_params_override
68
+ self.activation_quantizer_params_override = {} if activation_quantizer_params_override is None else activation_quantizer_params_override
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from typing import Callable
17
+ from functools import partial
17
18
 
18
19
  from model_compression_toolkit import CoreConfig
19
20
  from model_compression_toolkit.core import common
@@ -29,25 +30,56 @@ from model_compression_toolkit.ptq.runner import ptq_runner
29
30
 
30
31
  if FOUND_TF:
31
32
  import tensorflow as tf
33
+ from tensorflow.keras.layers import Layer
34
+ from tensorflow.keras.models import Model
32
35
 
33
36
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
34
37
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
35
38
  from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
36
- from tensorflow.keras.models import Model
37
39
  from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
38
40
 
39
- from model_compression_toolkit.qat.keras.qat_model_builder import QATKerasModelBuilder
41
+ from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
42
+
43
+ from model_compression_toolkit import get_target_platform_capabilities
44
+ from model_compression_toolkit import quantizers_infrastructure as qi
40
45
 
41
46
  from model_compression_toolkit import get_target_platform_capabilities
42
- from model_compression_toolkit import qunatizers_infrastructure as qi
47
+ from model_compression_toolkit.core import common
48
+ from model_compression_toolkit.core.common import BaseNode
49
+ from model_compression_toolkit.core.common.constants import TENSORFLOW
50
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
51
+ from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
52
+ from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
53
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
54
+ from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder
55
+ from model_compression_toolkit.qat.common.qat_config import QATConfig
56
+ from model_compression_toolkit import quantizers_infrastructure as qi
43
57
 
44
58
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
45
59
 
46
60
 
61
def qat_wrapper(n: common.BaseNode, layer: Layer, qat_config):
    """
    Wrap a Keras layer with trainable quantizers when QAT applies to its graph node.

    Args:
        n: A node of mct graph.
        layer: A keras layer
        qat_config: QAT configuration, forwarded to quantization_builder.

    Returns: A KerasQuantizationWrapper around the layer if QAT is applicable to the node,
        otherwise the layer itself, unchanged.

    """
    # Nodes that QAT does not apply to are passed through untouched.
    if not _is_qat_applicable(n, DEFAULT_KERAS_INFO):
        return layer

    w_quantizers, a_quantizers = quantization_builder(n, qat_config, DEFAULT_KERAS_INFO)
    return qi.KerasQuantizationWrapper(layer, w_quantizers, a_quantizers)
76
+
77
+
47
78
  def keras_quantization_aware_training_init(in_model: Model,
48
79
  representative_data_gen: Callable,
49
80
  target_kpi: KPI = None,
50
81
  core_config: CoreConfig = CoreConfig(),
82
+ qat_config: QATConfig = QATConfig(),
51
83
  fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
52
84
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
53
85
  """
@@ -70,6 +102,7 @@ if FOUND_TF:
70
102
  representative_data_gen (Callable): Dataset used for initial calibration.
71
103
  target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
72
104
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
105
+ qat_config (QATConfig): QAT configuration
73
106
  fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
74
107
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
75
108
 
@@ -90,14 +123,14 @@ if FOUND_TF:
90
123
  >>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
91
124
  >>> model = MobileNetV2()
92
125
 
93
- Create a random dataset generator, for required number of calibration iterations (num_calibration_batches):
94
- In this example a random dataset of 10 batches each containing 4 images is used.
126
+ Create a random dataset generator, for required number of calibration iterations (num_calibration_batches):
127
+ In this example a random dataset of 10 batches each containing 4 images is used.
95
128
 
96
- >>> import numpy as np
97
- >>> num_calibration_batches = 10
98
- >>> def repr_datagen():
99
- >>> for _ in range(num_calibration_batches):
100
- >>> yield [np.random.random((4, 224, 224, 3))]
129
+ >>> import numpy as np
130
+ >>> num_calibration_batches = 10
131
+ >>> def repr_datagen():
132
+ >>> for _ in range(num_calibration_batches):
133
+ >>> yield [np.random.random((4, 224, 224, 3))]
101
134
 
102
135
  Create a MCT core config, containing the quantization configuration:
103
136
 
@@ -154,24 +187,23 @@ if FOUND_TF:
154
187
 
155
188
  tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
156
189
 
157
- qat_model, user_info = QATKerasModelBuilder(graph=tg, fw_info=fw_info).build_model()
190
+ _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
191
+ qat_model, user_info = KerasModelBuilder(graph=tg, fw_info=fw_info, wrapper=_qat_wrapper).build_model()
158
192
 
159
193
  user_info.mixed_precision_cfg = bit_widths_config
160
194
  #TODO: remove the last output after updating documentation.
161
195
  return qat_model, user_info, {}
162
196
 
163
197
 
164
- def keras_quantization_aware_training_finalize(in_model: Model):
198
+ def keras_quantization_aware_training_finalize(in_model: Model) -> Model:
165
199
  """
166
- Convert a model fine-tuned by the user to a network without QuantizeWrappers. The exported
167
- model contains float (fake-quantized) parameters and fake-quantiztion layers for quantizing
168
- the activations
200
+ Convert a model fine-tuned by the user (Trainable quantizers) to a model with Inferable quantizers.
169
201
 
170
202
  Args:
171
- in_model (Model): Keras model to remove QuantizeWrappers.
203
+ in_model (Model): Keras model to replace TrainableQuantizer with InferableQuantizer
172
204
 
173
205
  Returns:
174
- A quantized model without QuantizeWrappers.
206
+ A quantized model with Inferable quantizers
175
207
 
176
208
  Examples:
177
209
 
@@ -216,37 +248,12 @@ if FOUND_TF:
216
248
  >>> quantized_model = mct.keras_quantization_aware_training_finalize(quantized_model)
217
249
 
218
250
  """
219
-
220
251
  def _export(layer):
221
252
  if isinstance(layer, qi.KerasQuantizationWrapper):
222
- if layer.dispatcher.is_weights_quantization:
223
- new_layer = layer.layer.__class__.from_config(layer.layer.get_config())
224
- with tf.name_scope(new_layer.name):
225
- new_layer.build(layer.input_shape)
226
- weights_list = []
227
- for w in new_layer.weights:
228
- val = None
229
- for qw in layer.weights:
230
- if w.name in qw.name:
231
- attribute_name = w.name.split('/')[-1].split(':')[0]
232
- if attribute_name in layer.dispatcher.weight_quantizers.keys():
233
- quantizer = layer.dispatcher.weight_quantizers.get(attribute_name)
234
- val = quantizer(qw, False)
235
- else:
236
- val = qw
237
- val = val.numpy()
238
- if val is None:
239
- Logger.error(f'Could not match weight name: {w.name}')
240
- weights_list.append(val)
241
- new_layer.set_weights(weights_list)
242
- new_layer.trainable = False
243
- return new_layer
244
- else:
245
- Logger.error(f'Undefined quantize_config')
246
- else:
247
- return layer
248
-
249
- # clone each layer in the model and apply _export to layers wrapped with a QuantizeWrapper.
253
+ layer.convert_to_inferable_quantizers()
254
+ return layer
255
+
256
+ # clone each layer in the model and apply _export to layers with TrainableQuantizeWrappers
250
257
  exported_model = tf.keras.models.clone_model(in_model, input_tensors=None, clone_function=_export)
251
258
 
252
259
  return exported_model
@@ -257,10 +264,10 @@ else:
257
264
  def keras_quantization_aware_training_init(*args, **kwargs):
258
265
  Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
259
266
  'when using keras_quantization_aware_training_init. '
260
- 'Could not find Tensorflow package.')
267
+ 'Could not find Tensorflow package.') # pragma: no cover
261
268
 
262
269
 
263
270
  def keras_quantization_aware_training_finalize(*args, **kwargs):
264
271
  Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
265
272
  'when using keras_quantization_aware_training_finalize. '
266
- 'Could not find Tensorflow package.')
273
+ 'Could not find Tensorflow package.') # pragma: no cover
@@ -12,3 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
16
+ import model_compression_toolkit.qat.keras.quantizer.ste_rounding.symmetric_ste
17
+ import model_compression_toolkit.qat.keras.quantizer.ste_rounding.uniform_ste
@@ -0,0 +1,49 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Union
16
+
17
+ from model_compression_toolkit.core.common import Logger
18
+ from model_compression_toolkit.core.common.constants import FOUND_TF
19
+
20
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
21
+ TrainableQuantizerActivationConfig, BaseKerasTrainableQuantizer
22
+
23
if not FOUND_TF:
    class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
        """
        Stand-in used when TensorFlow is not found: instantiation reports a critical error.
        """

        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            Initializes the base class and then logs a critical error about the missing TensorFlow package.

            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """

            super().__init__(quantization_config)
            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                            'when using BaseKerasQATTrainableQuantizer. '
                            'Could not find Tensorflow package.')  # pragma: no cover

else:

    class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
        """
        Common base class for the trainable Keras quantizers used in QAT.
        """

        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            Initializes BaseKerasQATTrainableQuantizer object.

            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """

            super().__init__(quantization_config)
@@ -0,0 +1,48 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import tensorflow as tf
17
+ from typing import Tuple
18
+
19
+
20
def adjust_range_to_include_zero(range_min: tf.Tensor,
                                 range_max: tf.Tensor,
                                 n_bits: int) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Adjust a quantization range so that 0.0 is representable on the quantization grid.

    Three cases are handled element-wise:
      * range_min > 0: the lower bound becomes 0 and range_max is kept.
      * range_max < 0: the upper bound becomes 0 and range_min is kept.
      * otherwise (the range contains zero): range_min is rounded to the nearest multiple of
        the quantization step and range_max is shifted so the range width is unchanged.

    For per-channel quantization, range_min and range_max should be tensors in the specific
    shape that allows quantization along the channel_axis.

    Args:
        range_min: min bound of the quantization range (before adjustment).
        range_max: max bound of the quantization range (before adjustment).
        n_bits: Number of bits to quantize the tensor.

    Returns: adjusted quantization range, as a (min, max) pair.
    """
    # Size of one quantization step for a grid of 2 ** n_bits levels.
    n_steps = 2 ** n_bits - 1
    step = (range_max - range_min) / n_steps

    # Candidate bounds for ranges that contain zero: snap the min onto the grid
    # and move the max by the same amount.
    snapped_min = step * tf.round(range_min / step)
    shifted_max = range_max - range_min + snapped_min

    all_positive = range_min > 0
    all_negative = range_max < 0
    # "neither strictly positive nor strictly negative" == the range contains zero.
    contains_zero = tf.logical_not(tf.logical_or(all_positive, all_negative))

    # Boolean selectors become 0/1 float32 masks so the cases can be blended arithmetically.
    positive_mask = tf.cast(all_positive, tf.float32)
    negative_mask = tf.cast(all_negative, tf.float32)
    zero_mask = tf.cast(contains_zero, tf.float32)

    adjusted_min = snapped_min * zero_mask + negative_mask * range_min
    adjusted_max = shifted_max * zero_mask + positive_mask * range_max

    return adjusted_min, adjusted_max
@@ -0,0 +1,77 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Tuple, Dict, List
16
+
17
+ from model_compression_toolkit.core import common
18
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
19
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
20
+ get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config, \
21
+ get_trainable_quantizer_quantization_candidates
22
+ from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
23
+ from model_compression_toolkit.qat.common.qat_config import QATConfig
24
+ from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
25
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
26
+ get_trainable_quantizer_class
27
+
28
+
29
def quantization_builder(n: common.BaseNode,
                         qat_config: QATConfig,
                         fw_info: FrameworkInfo,
                         ) -> Tuple[Dict[str, BaseKerasQATTrainableQuantizer], List[BaseKerasQATTrainableQuantizer]]:
    """
    Build quantizers for a node according to its quantization configuration.

    Args:
        n: Node to build its QuantizeConfig.
        qat_config (QATConfig): QAT configuration
        fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).

    Returns:
        weights_quantizers: A dictionary between a weight's name to its quantizer.
            Empty when weights quantization is disabled for the node.
        activation_quantizers: A list of activations quantization, one independent quantizer
            instance for each layer output. Empty when activation quantization is disabled.
    """
    # Quantization candidates are only collected when the node has more than one of them.
    if len(n.candidates_quantization_cfg) > 1:
        wq_cand, aq_cand = get_trainable_quantizer_quantization_candidates(n)
    else:
        wq_cand, aq_cand = None, None

    weight_quantizers = {}
    if n.is_weights_quantization_enabled():
        quant_method = n.final_weights_quantization_cfg.weights_quantization_method

        quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Weights,
                                                        qat_config.weight_training_method,
                                                        quant_method,
                                                        BaseKerasQATTrainableQuantizer)
        # One quantizer (with its own config object) per kernel attribute of the node's layer type.
        for attr in fw_info.get_kernel_op_attributes(n.type):
            weight_quantizers[attr] = quantizer_class(get_trainable_quantizer_weights_config(n, wq_cand),
                                                      **qat_config.weight_quantizer_params_override)

    activation_quantizers = []
    if n.is_activation_quantization_enabled():
        quant_method = n.final_activation_quantization_cfg.activation_quantization_method
        # single output -> normalize to list of output_shapes
        output_shapes = n.output_shape if isinstance(n.output_shape[0], (list, tuple)) else [n.output_shape]

        quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Activation,
                                                        qat_config.activation_training_method,
                                                        quant_method,
                                                        BaseKerasQATTrainableQuantizer)

        # Instantiate a separate quantizer per output. `[quantizer] * k` would place the SAME
        # object in every slot, so all outputs would share a single quantizer instance.
        activation_quantizers = [quantizer_class(get_trainable_quantizer_activation_config(n, aq_cand),
                                                 **qat_config.activation_quantizer_params_override)
                                 for _ in output_shapes]

    return weight_quantizers, activation_quantizers