mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,204 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
23
+ from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
24
+ from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
25
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
26
+ from model_compression_toolkit.core.common import constants as C
27
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
28
+ from model_compression_toolkit.qat.pytorch.quantizer.quantizer_utils import ste_round, ste_clip, symmetric_quantizer
29
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
30
+ WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, ActivationPOTInferableQuantizer, \
31
+ ActivationSymmetricInferableQuantizer
32
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
33
+ TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
34
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
35
+
36
+
37
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
38
+ quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
39
+ quantizer_type=TrainingMethod.STE)
40
+ class STEWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
41
+ """
42
+ Trainable constrained quantizer to quantize a layer weights.
43
+ """
44
+
45
+ def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
46
+ """
47
+ Initialize a TrainableWeightQuantizer object with parameters to use
48
+ for the quantization.
49
+
50
+ Args:
51
+ quantization_config: trainable quantizer config class
52
+ """
53
+ super().__init__(quantization_config)
54
+ self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
55
+ self.threshold_values = quantization_config.weights_quantization_params[C.THRESHOLD]
56
+ self.threshold_shape = np.asarray(self.threshold_values).shape
57
+ self.np_threshold_values = self.threshold_values
58
+
59
+ if self.power_of_two:
60
+ self.np_threshold_values = np.power(2.0,
61
+ np.ceil(np.log2(np.maximum(self.np_threshold_values, C.MIN_THRESHOLD))))
62
+ self.num_bits = self.quantization_config.weights_n_bits
63
+ n_pos_bits = self.num_bits - int(C.WEIGHTS_SIGNED)
64
+ delta = self.np_threshold_values / np.power(2.0, n_pos_bits)
65
+ self.delta_tensor = to_torch_tensor(delta)
66
+ self.min_int = -int(C.WEIGHTS_SIGNED) * (2 ** n_pos_bits)
67
+ self.max_int = (2 ** n_pos_bits) - 1
68
+ self.min = delta * self.min_int
69
+ self.max = delta * self.max_int
70
+
71
+
72
+ def initialize_quantization(self,
73
+ tensor_shape: torch.Size,
74
+ name: str,
75
+ layer: qi.PytorchQuantizationWrapper):
76
+ """
77
+ Add quantizer parameters to the quantizer parameters dictionary
78
+
79
+ Args:
80
+ tensor_shape: tensor shape of the quantized tensor.
81
+ name: Tensor name.
82
+ layer: Layer to quantize.
83
+ """
84
+
85
+ # Add threshold variables to layer.
86
+ layer.register_parameter(name + "_" + THRESHOLD_TENSOR, nn.Parameter(to_torch_tensor(self.np_threshold_values),
87
+ requires_grad=False))
88
+
89
+ # save the quantizer added parameters for later calculations
90
+ self.add_quantizer_variable(THRESHOLD_TENSOR, layer.get_parameter(name + "_" + THRESHOLD_TENSOR), VariableGroup.QPARAMS)
91
+
92
+
93
+ def __call__(self,
94
+ inputs: nn.Parameter,
95
+ training: bool) -> nn.Parameter:
96
+ """
97
+ Quantize a tensor
98
+ Args:
99
+ inputs: Input tensor to quantize.
100
+ training: whether in training mode or not
101
+ Returns:
102
+ quantized tensor
103
+ """
104
+ w0 = ste_round(inputs / self.delta_tensor)
105
+ w1 = ste_clip(w0, min_val=self.min_int, max_val=self.max_int)
106
+ w_q = self.delta_tensor * w1
107
+ return w_q
108
+
109
+ def convert2inferable(self) -> Union[WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer]:
110
+ """
111
+ Convert quantizer to inferable quantizer.
112
+
113
+ Returns:
114
+ A pytorch inferable quanizer object.
115
+ """
116
+ np_threshold = self.get_quantizer_variable(THRESHOLD_TENSOR).cpu().detach().numpy().flatten()
117
+ if self.power_of_two:
118
+ pot_threshold = 2 ** np.ceil(np.log2(np_threshold))
119
+ return WeightsPOTInferableQuantizer(num_bits=self.num_bits,
120
+ threshold=pot_threshold,
121
+ per_channel=self.quantization_config.weights_per_channel_threshold,
122
+ channel_axis=self.quantization_config.weights_channels_axis)
123
+ else:
124
+ return WeightsSymmetricInferableQuantizer(num_bits=self.num_bits,
125
+ threshold=np_threshold,
126
+ per_channel=self.quantization_config.weights_per_channel_threshold,
127
+ channel_axis=self.quantization_config.weights_channels_axis)
128
+
129
+
130
+
131
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
132
+ quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
133
+ quantizer_type=TrainingMethod.STE)
134
+ class STEActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
135
+ """
136
+ Trainable constrained quantizer to quantize a layer activations.
137
+ """
138
+
139
+ def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
140
+ """
141
+ Initialize a STEActivationQATQuantizer object with parameters to use
142
+ for symmetric or power of two quantization.
143
+
144
+ Args:
145
+ quantization_config: trainable quantizer config class
146
+ """
147
+ super().__init__(quantization_config)
148
+ self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
149
+ self.sign = quantization_config.activation_quantization_params['is_signed']
150
+ np_threshold_values = quantization_config.activation_quantization_params[C.THRESHOLD]
151
+ self.threshold_tensor = torch.Tensor([np_threshold_values])
152
+ self.num_bits = quantization_config.activation_n_bits
153
+
154
+ def initialize_quantization(self,
155
+ tensor_shape: torch.Size,
156
+ name: str,
157
+ layer: qi.PytorchQuantizationWrapper):
158
+ """
159
+ Add quantizer parameters to the quantizer parameters dictionary
160
+
161
+ Args:
162
+ tensor_shape: tensor shape of the quantized tensor.
163
+ name: Tensor name.
164
+ layer: Layer to quantize.
165
+ """
166
+ layer.register_parameter(name, nn.Parameter(to_torch_tensor(self.threshold_tensor), requires_grad=True))
167
+
168
+ # save the quantizer added parameters for later calculations
169
+ self.add_quantizer_variable(THRESHOLD_TENSOR, layer.get_parameter(name), VariableGroup.QPARAMS)
170
+
171
+ def __call__(self,
172
+ inputs: torch.Tensor,
173
+ training: bool = True) -> torch.Tensor:
174
+ """
175
+ Quantize a tensor.
176
+ Args:
177
+ inputs: Input tensor to quantize.
178
+ training: Whether the graph is in training mode.
179
+
180
+ Returns:
181
+ The quantized tensor.
182
+ """
183
+
184
+ _t = self.get_quantizer_variable(THRESHOLD_TENSOR)
185
+ q_tensor = symmetric_quantizer(inputs, _t, self.num_bits, sign=self.sign)
186
+ return q_tensor
187
+
188
+ def convert2inferable(self) -> Union[ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer]:
189
+ """
190
+ Convert quantizer to inferable quantizer.
191
+
192
+ Returns:
193
+ A pytorch inferable quanizer object.
194
+ """
195
+ np_threshold = self.get_quantizer_variable(THRESHOLD_TENSOR).cpu().detach().numpy()
196
+ if self.power_of_two:
197
+ pot_threshold = np.power(2.0, np.ceil(np.log2(np_threshold)))
198
+ return ActivationPOTInferableQuantizer(num_bits=self.num_bits,
199
+ threshold=pot_threshold,
200
+ signed=self.sign)
201
+ else:
202
+ return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits,
203
+ threshold=np_threshold,
204
+ signed=self.sign)
@@ -0,0 +1,190 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+
20
+ from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
21
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
23
+ from model_compression_toolkit.core.common import constants as C
24
+ from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
25
+ from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
26
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
27
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
28
+ from model_compression_toolkit.qat.pytorch.quantizer.quantizer_utils import uniform_quantizer
29
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
30
+ WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
31
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
32
+ TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
33
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
34
+
35
+
36
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=TrainingMethod.STE)
class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
    """
    Straight-through-estimator QAT quantizer for layer weights, using a uniform
    (min/max range) quantization scheme.
    """

    def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
        """
        Initialize the quantizer from a trainable weights-quantizer configuration.

        Args:
            quantization_config: Trainable quantizer configuration for weights.
        """
        super().__init__(quantization_config)
        self.num_bits = self.quantization_config.weights_n_bits
        self.min_int = 0
        self.max_int = 2 ** self.num_bits - 1
        self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
        self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
        self.min_max_shape = np.asarray(self.max_values).shape
        # Per-channel quantization keeps a flat vector of ranges; per-tensor keeps scalars.
        if self.quantization_config.weights_per_channel_threshold:
            self.max = np.reshape(self.max_values, [-1])
            self.min = np.reshape(self.min_values, [-1])
        else:
            self.max = float(self.max_values)
            self.min = float(self.min_values)

    def initialize_quantization(self,
                                tensor_shape: torch.Size,
                                name: str,
                                layer: qi.PytorchQuantizationWrapper):
        """
        Register the quantization range as parameters on the wrapped layer.

        Args:
            tensor_shape: Shape of the tensor to be quantized.
            name: Tensor name, used to derive the parameter names.
            layer: Quantization wrapper holding the layer to quantize.
        """
        min_name = name + "_" + FQ_MIN
        max_name = name + "_" + FQ_MAX

        # The range is frozen for STE weight quantization, hence requires_grad=False.
        layer.register_parameter(min_name, nn.Parameter(to_torch_tensor(self.min_values), requires_grad=False))
        layer.register_parameter(max_name, nn.Parameter(to_torch_tensor(self.max_values), requires_grad=False))

        # Keep handles to the registered parameters for use during quantization.
        self.add_quantizer_variable(FQ_MIN, layer.get_parameter(min_name), VariableGroup.QPARAMS)
        self.add_quantizer_variable(FQ_MAX, layer.get_parameter(max_name), VariableGroup.QPARAMS)

    def __call__(self,
                 inputs: nn.Parameter,
                 training: bool) -> Tensor:
        """
        Fake-quantize a weight tensor with the stored uniform range.

        Args:
            inputs: Weight tensor to quantize.
            training: Whether the graph is currently in training mode.

        Returns:
            The quantized tensor.
        """
        range_min = self.get_quantizer_variable(FQ_MIN)
        range_max = self.get_quantizer_variable(FQ_MAX)
        return uniform_quantizer(inputs, range_min, range_max, self.num_bits)

    def convert2inferable(self) -> WeightsUniformInferableQuantizer:
        """
        Build a PyTorch inferable uniform weights quantizer from the stored range.

        Returns:
            A pytorch inferable quantizer object.
        """
        range_min = self.get_quantizer_variable(FQ_MIN).cpu().detach().numpy()
        range_max = self.get_quantizer_variable(FQ_MAX).cpu().detach().numpy()

        return WeightsUniformInferableQuantizer(num_bits=self.num_bits,
                                                min_range=range_min,
                                                max_range=range_max,
                                                per_channel=self.quantization_config.weights_per_channel_threshold,
                                                channel_axis=self.quantization_config.weights_channels_axis)
116
+
117
+
118
@mark_quantizer(quantization_target=qi.QuantizationTarget.Activation,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=TrainingMethod.STE)
class STEUniformActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
    """
    Straight-through-estimator QAT quantizer for layer activations, using a uniform
    quantization scheme with a trainable min/max range.
    """

    def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
        """
        Initialize the quantizer from a trainable activation-quantizer configuration.

        Args:
            quantization_config: Trainable quantizer configuration for activations.
        """
        super().__init__(quantization_config)

        range_min = quantization_config.activation_quantization_params[C.RANGE_MIN]
        range_max = quantization_config.activation_quantization_params[C.RANGE_MAX]
        self.min_range_tensor = torch.Tensor([range_min])
        self.max_range_tensor = torch.Tensor([range_max])
        self.num_bits = quantization_config.activation_n_bits

    def initialize_quantization(self,
                                tensor_shape: torch.Size,
                                name: str,
                                layer: qi.PytorchQuantizationWrapper):
        """
        Register the quantization range as trainable parameters on the wrapped layer.

        Args:
            tensor_shape: Shape of the tensor to be quantized.
            name: Tensor name, used to derive the parameter names.
            layer: Quantization wrapper holding the layer to quantize.
        """
        min_name = name + "_" + FQ_MIN
        max_name = name + "_" + FQ_MAX

        # The activation range is learned during QAT, hence requires_grad=True.
        layer.register_parameter(min_name, nn.Parameter(to_torch_tensor(self.min_range_tensor), requires_grad=True))
        layer.register_parameter(max_name, nn.Parameter(to_torch_tensor(self.max_range_tensor), requires_grad=True))

        # Keep handles to the registered parameters for use during quantization.
        self.add_quantizer_variable(FQ_MIN, layer.get_parameter(min_name), VariableGroup.QPARAMS)
        self.add_quantizer_variable(FQ_MAX, layer.get_parameter(max_name), VariableGroup.QPARAMS)

    def __call__(self,
                 inputs: torch.Tensor,
                 training: bool = True) -> torch.Tensor:
        """
        Fake-quantize an activation tensor with the trainable uniform range.

        Args:
            inputs: Tensor to quantize.
            training: Whether the graph is currently in training mode.

        Returns:
            The quantized tensor.
        """
        range_min = self.get_quantizer_variable(FQ_MIN)
        range_max = self.get_quantizer_variable(FQ_MAX)
        return uniform_quantizer(inputs, range_min, range_max, self.num_bits)

    def convert2inferable(self) -> ActivationUniformInferableQuantizer:
        """
        Build a PyTorch inferable uniform activation quantizer from the trained range.

        Returns:
            A pytorch inferable quantizer object.
        """
        range_min = self.get_quantizer_variable(FQ_MIN).cpu().detach().numpy()
        range_max = self.get_quantizer_variable(FQ_MAX).cpu().detach().numpy()

        return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
                                                   min_range=range_min, max_range=range_max)
@@ -0,0 +1,23 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget, BaseInferableQuantizer
17
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
18
+ TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
19
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.base_keras_quantizer import BaseKerasTrainableQuantizer
20
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantize_wrapper import KerasQuantizationWrapper
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantize_wrapper import PytorchQuantizationWrapper
23
+
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -0,0 +1,87 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from enum import Enum
16
+ from typing import Any, Dict, List
17
+
18
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+
20
+
21
class QuantizationTarget(Enum):
    """Specifies what a quantizer is applied to: a layer's activations or its weights."""
    Activation = "Activation"
    Weights = "Weights"
24
+
25
+
26
def mark_quantizer(quantization_target: QuantizationTarget = None,
                   quantization_method: List[QuantizationMethod] = None,
                   quantizer_type: Any = None):
    """
    Class decorator for inferable quantizers (classes inheriting from BaseInferableQuantizer).
    Decorating a quantizer class stamps the required static properties onto it.

    Args:
        quantization_target: QuantizationTarget value indicating what the quantizer is
            intended to quantize.
        quantization_method: List of QuantizationMethod values the quantizer supports.
        quantizer_type: The quantizer's type (quantization technique); its meaning may
            differ depending on the quantizer's purpose.

    Returns: A function that decorates a class object.

    """
    def mark(quantizer_class_object: BaseInferableQuantizer):
        """
        Attach the configured properties to the decorated class.

        Args:
            quantizer_class_object: The class to be decorated.

        Returns: The decorated class.

        """
        stamped_properties = {'quantization_target': quantization_target,
                              'quantization_method': quantization_method,
                              'quantizer_type': quantizer_type}
        for property_name, property_value in stamped_properties.items():
            setattr(quantizer_class_object, property_name, property_value)

        return quantizer_class_object

    return mark
61
+
62
+
63
class BaseInferableQuantizer:
    """
    Base quantizer class defining the interface that every quantizer implements.
    """

    def __init__(self):
        """Create a base quantizer; the base class holds no state."""
        pass

    def initialize_quantization(self,
                                tensor_shape: Any,
                                name: str,
                                layer: Any) -> Dict[Any, Any]:
        """
        Return a mapping of quantizer parameter names to their variables.
        The base implementation declares no parameters.

        Args:
            tensor_shape: Tensor shape of the quantized tensor.
            name: Tensor name.
            layer: Layer to quantize.

        Returns:
            Dictionary of parameter names to variables (empty for the base class).
        """
        return dict()
@@ -0,0 +1,41 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
# Keys/flags describing what a quantizer configuration applies to.
IS_WEIGHTS = "is_weights"
IS_ACTIVATIONS = "is_activations"
ACTIVATION_QUANTIZERS = "activation_quantizers"
WEIGHTS_QUANTIZERS = "weights_quantizer"

# Attribute/parameter names of the trainable weights-quantizer configuration.
WEIGHTS_QUANTIZATION_METHOD = 'weights_quantization_method'
WEIGHTS_N_BITS = 'weights_n_bits'
WEIGHTS_QUANTIZATION_PARAMS = 'weights_quantization_params'
ENABLE_WEIGHTS_QUANTIZATION = 'enable_weights_quantization'
WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
WEIGHTS_PER_CHANNEL_THRESHOLD = 'weights_per_channel_threshold'
MIN_THRESHOLD = 'min_threshold'

# Attribute/parameter names of the trainable activation-quantizer configuration.
ACTIVATION_QUANTIZATION_METHOD = 'activation_quantization_method'
ACTIVATION_N_BITS = 'activation_n_bits'
ACTIVATION_QUANTIZATION_PARAMS = 'activation_quantization_params'
ENABLE_ACTIVATION_QUANTIZATION = 'enable_activation_quantization'

LAYER = "layer"
STEPS = "optimizer_step"
TRAINING = "training"


# Attribute names stamped onto quantizer classes by the mark_quantizer decorator
# and looked up via getattr when searching for a matching quantizer class.
QUANTIZATION_TARGET = 'quantization_target'
QUANTIZATION_METHOD = 'quantization_method'
QUANTIZER_TYPE = 'quantizer_type'

# Small numeric epsilon — presumably guards divisions/logs against zero; confirm at usage sites.
EPS = 1e-8
MULTIPLIER_N_BITS = 8
@@ -0,0 +1,31 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Set
16
+
17
+
18
def get_all_subclasses(cls: type) -> Set[type]:
    """
    Collect every class that inherits from `cls`, directly or transitively.

    Args:
        cls: A class object.

    Returns: All classes that inherit from cls.

    """
    # Iterative worklist traversal over the subclass tree.
    collected = set()
    pending = list(cls.__subclasses__())
    while pending:
        subclass = pending.pop()
        if subclass not in collected:
            collected.add(subclass)
            pending.extend(subclass.__subclasses__())
    return collected
@@ -0,0 +1,53 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.core.common import Logger
17
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
19
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import QUANTIZATION_TARGET, \
20
+ QUANTIZATION_METHOD
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses import get_all_subclasses
22
+
23
+
24
def get_inferable_quantizer_class(quant_target: QuantizationTarget,
                                  quant_method: QuantizationMethod,
                                  quantizer_base_class: type) -> type:
    """
    Searches for an inferable quantizer class that matches the requested QuantizationTarget
    and QuantizationMethod. Exactly one class should be found.

    Args:
        quant_target: QuantizationTarget value (Weights or Activation) which indicates what is
            the target for quantization to use the quantizer for.
        quant_method: A QuantizationMethod value that the requested quantizer must support.
        quantizer_base_class: A type of quantizer that the requested quantizer should inherit from.

    Returns: A class of a quantizer that inherits from quantizer_base_class.

    """
    qat_quantizer_classes = get_all_subclasses(quantizer_base_class)
    # A match must declare the requested target and list the requested method among its supported ones.
    filtered_quantizers = list(filter(lambda q_class: getattr(q_class, QUANTIZATION_TARGET) == quant_target and
                                      getattr(q_class, QUANTIZATION_METHOD) is not None and
                                      quant_method in getattr(q_class, QUANTIZATION_METHOD),
                                      qat_quantizer_classes))

    if len(filtered_quantizers) != 1:
        # NOTE(review): assumes Logger.error raises; otherwise the indexing below
        # fails with IndexError when no quantizer matched — confirm Logger semantics.
        # Fixed: the original message ran two sentences together ("one.The possible").
        Logger.error(f"Found {len(filtered_quantizers)} quantizer for target {quant_target.value} "
                     f"that matches the requested quantization method {quant_method.name} "
                     f"but there should be exactly one. "
                     f"The possible quantizers that were found are {filtered_quantizers}.")

    return filtered_quantizers[0]