mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,116 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List
16
+ from model_compression_toolkit.core.common import BaseNode, Logger
17
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
18
+ TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig, TrainableQuantizerCandidateConfig
19
+
20
+
21
def get_trainable_quantizer_weights_config(
        n: BaseNode,
        weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None
) -> TrainableQuantizerWeightsConfig:
    """
    Build the weights trainable-quantizer configuration for a node.

    Args:
        n: Node to build a trainable quantizer configuration from. Must carry a
            final_weights_quantization_cfg.
        weights_quantization_candidates: Optional list of weights quantizer config candidates.

    Returns:
        TrainableQuantizerWeightsConfig: an object that contains the quantizer configuration,
        populated from the node's final weights quantization config.
    """
    if n.final_weights_quantization_cfg is None:
        # NOTE(review): assumes Logger.error raises (otherwise the attribute access
        # below would dereference None) — confirm against the Logger implementation.
        Logger.error('Node must have final_weights_quantization_cfg in order to build quantizer configuration')  # pragma: no cover

    final_cfg = n.final_weights_quantization_cfg
    return TrainableQuantizerWeightsConfig(final_cfg.weights_quantization_method,
                                           final_cfg.weights_n_bits,
                                           final_cfg.weights_quantization_params,
                                           final_cfg.enable_weights_quantization,
                                           final_cfg.weights_channels_axis,
                                           final_cfg.weights_per_channel_threshold,
                                           final_cfg.min_threshold,
                                           weights_quantization_candidates)
47
+
48
+
49
def get_trainable_quantizer_activation_config(
        n: BaseNode,
        activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None
) -> TrainableQuantizerActivationConfig:
    """
    Build the activation trainable-quantizer configuration for a node.

    Args:
        n: Node to build a trainable quantizer configuration from. Must carry a
            final_activation_quantization_cfg.
        activation_quantization_candidates: Optional list of activation quantizer config candidates.

    Returns:
        TrainableQuantizerActivationConfig: an object that contains the quantizer configuration,
        populated from the node's final activation quantization config.
    """
    if n.final_activation_quantization_cfg is None:
        # NOTE(review): assumes Logger.error raises (otherwise the attribute access
        # below would dereference None) — confirm against the Logger implementation.
        Logger.error('Node must have final_activation_quantization_cfg in order to build quantizer configuration')  # pragma: no cover

    final_cfg = n.final_activation_quantization_cfg
    return TrainableQuantizerActivationConfig(final_cfg.activation_quantization_method,
                                              final_cfg.activation_n_bits,
                                              final_cfg.activation_quantization_params,
                                              final_cfg.enable_activation_quantization,
                                              final_cfg.min_threshold,
                                              activation_quantization_candidates)
73
+
74
+
75
def get_trainable_quantizer_quantization_candidates(n: BaseNode):
    """
    Build quantization configuration candidates for activation and weights trainable quantizers,
    after validating that the node's candidates are compatible with a trainable quantizer.

    Validation (each failure is reported via Logger.error):
      * all candidates share one weights quantization method,
      * all candidates share one activation quantization method,
      * the candidate list is the full cartesian product of the unique
        weights-bit-widths and activation-bit-widths.

    Args:
        n: BaseNode - the node to build a trainable quantizer from.

    Returns:
        Tuple of two lists:
            weights_cfg_candidates - configuration candidates for weights,
            activation_cfg_candidates - configuration candidates for activation.
    """
    # All candidates must use the same weights quantization method.
    weights_quantization_methods = {cfg.weights_quantization_cfg.weights_quantization_method
                                    for cfg in n.candidates_quantization_cfg}
    if len(weights_quantization_methods) > 1:
        Logger.error(f'Unsupported candidates_quantization_cfg with different weights quantization methods: {weights_quantization_methods}')  # pragma: no cover

    # All candidates must use the same activation quantization method.
    activation_quantization_methods = {cfg.activation_quantization_cfg.activation_quantization_method
                                       for cfg in n.candidates_quantization_cfg}
    if len(activation_quantization_methods) > 1:
        Logger.error(f'Unsupported candidates_quantization_cfg with different activation quantization methods: {activation_quantization_methods}')  # pragma: no cover

    # Unique candidate lists per axis (weights / activation).
    unique_weights_candidates = n.get_unique_weights_candidates()
    unique_activation_candidates = n.get_unique_activation_candidates()

    # Verify all combinations of weights_n_bits and activation_n_bits are present.
    if len(n.candidates_quantization_cfg) != len(unique_weights_candidates) * len(unique_activation_candidates):
        # Fixed message: the original concatenated f-strings produced "quantizer,it must" (missing space).
        Logger.error(f'Unsupported candidates_quantization_cfg for a trainable quantizer, '
                     f'it must contain all the combinations of (weights_n_bits X activations_n_bits)')  # pragma: no cover

    # Generate the list of weights quantizer candidates.
    weights_cfg_candidates = [TrainableQuantizerCandidateConfig(
        cfg.weights_quantization_cfg.weights_n_bits,
        cfg.weights_quantization_cfg.weights_quantization_params) for cfg in unique_weights_candidates]

    # Generate the list of activation quantizer candidates.
    activation_cfg_candidates = [TrainableQuantizerCandidateConfig(
        cfg.activation_quantization_cfg.activation_n_bits,
        cfg.activation_quantization_cfg.activation_quantization_params) for cfg in unique_activation_candidates]

    return weights_cfg_candidates, activation_cfg_candidates
@@ -0,0 +1,65 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Union
16
+
17
+ from model_compression_toolkit.gptq import RoundingType
18
+ from model_compression_toolkit import TrainingMethod
19
+ from model_compression_toolkit.core.common import Logger
20
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
23
+ import QUANTIZATION_TARGET, QUANTIZATION_METHOD, QUANTIZER_TYPE
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses \
25
+ import get_all_subclasses
26
+
27
+
28
def get_trainable_quantizer_class(quant_target: QuantizationTarget,
                                  quantizer_type: Union[TrainingMethod, RoundingType],
                                  quant_method: QuantizationMethod,
                                  quantizer_base_class: type) -> type:
    """
    Search for a trainable quantizer class that matches the requested QuantizationTarget,
    QuantizationMethod and task-dedicated quantizer type. Exactly one class must match.

    Args:
        quant_target: QuantizationTarget value which indicates what is the target for quantization to
            use the quantizer for.
        quantizer_type: The type of the quantizer (quantization technique).
            This can differ, depending on the purpose the quantizer is for.
        quant_method: A QuantizationMethod value indicating the quantization method the
            quantizer must support.
        quantizer_base_class: A type of quantizer that the requested quantizer should inherit from.

    Returns:
        The single quantizer class that inherits from quantizer_base_class and matches
        all the requested properties.
    """
    candidate_classes = get_all_subclasses(quantizer_base_class)
    if len(candidate_classes) == 0:
        Logger.error(f"No quantizers were found that inherit from {quantizer_base_class}.")  # pragma: no cover

    def _matches(q_class) -> bool:
        # A class matches only if it declares all three marker attributes and they
        # agree with the requested target, method and quantizer type.
        if getattr(q_class, QUANTIZATION_TARGET, None) is None:
            return False
        if getattr(q_class, QUANTIZATION_TARGET) != quant_target:
            return False
        if getattr(q_class, QUANTIZATION_METHOD, None) is None:
            return False
        if quant_method not in getattr(q_class, QUANTIZATION_METHOD, []):
            return False
        return getattr(q_class, QUANTIZER_TYPE, None) == quantizer_type

    filtered_quantizers = [q_class for q_class in candidate_classes if _matches(q_class)]

    # Ambiguity (0 or >1 matches) is a configuration error.
    if len(filtered_quantizers) != 1:
        Logger.error(f"Found {len(filtered_quantizers)} quantizer for target {quant_target.value} "  # pragma: no cover
                     f"that matches the requested quantization method {quant_method.name} and "
                     f"quantizer type {quantizer_type.value} but there should be exactly one."
                     f"The possible quantizers that were found are {filtered_quantizers}.")

    return filtered_quantizers[0]
@@ -0,0 +1,36 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Tuple, List
16
+
17
+
18
def get_threshold_reshape_shape(tensor_shape: Tuple, quant_axis: int, quant_axis_dim: int) -> List[int]:
    """
    Gets a shape that contains 1 in all axes except the quantization axis, to adjust the threshold
    tensor for per-channel quantization.

    Args:
        tensor_shape: The shape of the tensor to be quantized.
        quant_axis: The axis along which the quantization happens (usually the tensor's channel axis).
            Negative values are interpreted relative to the end, as in standard indexing.
        quant_axis_dim: The dimension of the quantization axis.

    Returns: A shape to reshape the threshold tensor according to.

    """
    n_axis = len(tensor_shape)
    # Normalize a negative axis (e.g. -1 for "last axis") to its positive index.
    quantization_axis = n_axis + quant_axis if quant_axis < 0 else quant_axis

    return [quant_axis_dim if i == quantization_axis else 1 for i in range(n_axis)]
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from abc import ABC
16
+ from typing import Dict, List
17
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+
19
+
20
class TrainableQuantizerCandidateConfig:
    # Lightweight data holder: one (bit-width, params) candidate for a trainable quantizer.
    # NOTE(review): attribute names are significant - sibling serialization code iterates
    # __dict__ - so keep them stable.

    def __init__(self,
                 n_bits: int,
                 quantization_params: Dict,
                 ):
        """
        Class for representing candidates of quantization configurations for trainable quantizer.
        It can be used for weights and activation quantization configuration.

        Args:
            n_bits (int): Number of bits to use for quantization.
            quantization_params (Dict): Dictionary that contains quantization params.
        """

        # Bit-width of this candidate.
        self.n_bits = n_bits
        # Quantization parameters (e.g. thresholds) of this candidate.
        self.quantization_params = quantization_params
37
+
38
+
39
class TrainableQuantizerActivationConfig:
    # Data holder for configuring an activations trainable quantizer.
    # NOTE(review): attribute names are significant - config serialization iterates
    # __dict__ - so keep them stable.

    def __init__(self,
                 activation_quantization_method: QuantizationMethod,
                 activation_n_bits: int,
                 activation_quantization_params: Dict,
                 enable_activation_quantization: bool,
                 min_threshold: float,
                 activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
                 ):
        """
        Attributes for configuring activations trainable quantizer.

        Args:
            activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
            activation_n_bits (int): Number of bits to quantize the activations.
            activation_quantization_params (Dict): Dictionary that contains activation quantization params.
            enable_activation_quantization (bool): Whether to quantize the layer's activations or not.
            min_threshold (float): Minimum threshold to use during thresholds selection.
            activation_quantization_candidates (List[TrainableQuantizerCandidateConfig]): Optional list of
                bit-width candidates (stored as `activation_bits_candidates`); may be None.
        """
        self.activation_quantization_method = activation_quantization_method
        self.activation_n_bits = activation_n_bits
        self.activation_quantization_params = activation_quantization_params
        self.enable_activation_quantization = enable_activation_quantization
        self.min_threshold = min_threshold
        # Note: the parameter is stored under a different attribute name.
        self.activation_bits_candidates = activation_quantization_candidates
65
+
66
+
67
class TrainableQuantizerWeightsConfig:
    # Data holder for configuring a weights trainable quantizer.
    # NOTE(review): attribute names are significant - config serialization iterates
    # __dict__ - so keep them stable.
    def __init__(self,
                 weights_quantization_method: QuantizationMethod,
                 weights_n_bits: int,
                 weights_quantization_params: Dict,
                 enable_weights_quantization: bool,
                 weights_channels_axis: int,
                 weights_per_channel_threshold: bool,
                 min_threshold: float,
                 weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
                 ):
        """
        Attributes for configuring weights trainable quantizer.

        Args:
            weights_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for weights quantization.
            weights_n_bits (int): Number of bits to quantize the coefficients.
            weights_quantization_params (Dict): Dictionary that contains weights quantization params.
            enable_weights_quantization (bool): Whether to quantize the layer's weights or not.
            weights_channels_axis (int): Axis to quantize a node's kernel when quantizing per-channel.
            weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
            min_threshold (float): Minimum threshold to use during thresholds selection.
            weights_quantization_candidates (List[TrainableQuantizerCandidateConfig]): Optional list of
                bit-width candidates (stored as `weights_bits_candidates`); may be None.
        """
        self.weights_quantization_method = weights_quantization_method
        self.weights_n_bits = weights_n_bits
        self.weights_quantization_params = weights_quantization_params
        self.enable_weights_quantization = enable_weights_quantization
        self.weights_channels_axis = weights_channels_axis
        self.weights_per_channel_threshold = weights_per_channel_threshold
        self.min_threshold = min_threshold
        # Note: the parameter is stored under a different attribute name.
        self.weights_bits_candidates = weights_quantization_candidates
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,90 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Dict, Any, Union, List
16
+
17
+ from model_compression_toolkit.core.common import Logger
18
+ from model_compression_toolkit.core.common.constants import FOUND_TF
19
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
22
+ TrainableQuantizerActivationConfig
23
+
24
# TensorFlow is an optional dependency: the real Keras quantizer base is defined only
# when TF is installed; otherwise a stub that fails fast at construction is defined.
if FOUND_TF:
    # Key under which the serialized quantization config is stored in get_config().
    QUANTIZATION_CONFIG = 'quantization_config'
    from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.config_serialization import config_serialization, \
        config_deserialization
    import tensorflow as tf

    class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            This class is a base quantizer which validates provided quantization config and defines an abstract function which any quantizer needs to implement.
            This class adds to the base quantizer a get_config and from_config functions to enable loading and saving the keras model.

            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """
            super().__init__(quantization_config)

        def get_config(self) -> Dict[str, Any]:
            """
            Returns: Configuration of BaseKerasQuantizer - the quantization config serialized
            into a plain dictionary so Keras can save the model.
            """
            return {QUANTIZATION_CONFIG: config_serialization(self.quantization_config)}

        @classmethod
        def from_config(cls, config: dict):
            """
            Rebuild a quantizer from its serialized configuration (inverse of get_config).

            Args:
                config (dict): dictionary of BaseKerasQuantizer configuration.

            Returns: A BaseKerasQuantizer.
            """
            config = config.copy()
            quantization_config = config_deserialization(config[QUANTIZATION_CONFIG])
            # Note that a quantizer only receives the quantization config; everything else
            # is hardcoded inside the specific quantizer subclass.
            return cls(quantization_config=quantization_config)

        def get_trainable_variables(self, group: VariableGroup) -> List[tf.Tensor]:
            """
            Get trainable parameters with specific group from quantizer

            Args:
                group: Enum of variable group

            Returns:
                List of trainable variables
            """
            quantizer_trainable = []
            for name, parameter_dict in self.quantizer_parameters.items():
                quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
                # Keras variables expose `trainable`; collect only those in the requested group.
                if quantizer_parameter.trainable and parameter_group == group:
                    quantizer_trainable.append(quantizer_parameter)
            return quantizer_trainable


else:
    class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
        # Stub used when TensorFlow is missing: any attempt to construct a quantizer
        # raises via Logger.critical with an installation hint.
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):

            super().__init__(quantization_config)
            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                            'when using BaseKerasQuantizer. '
                            'Could not find Tensorflow package.')  # pragma: no cover
@@ -0,0 +1,80 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import copy
16
+
17
+ from typing import Any, Union
18
+ from enum import Enum
19
+
20
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
22
+ TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
23
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common import constants as C
24
+
25
+
26
def transform_enum(v: Any):
    """
    Return the underlying value of an Enum member; any non-Enum input is returned unchanged.

    Args:
        v: Any type

    Returns: Any

    """
    return v.value if isinstance(v, Enum) else v
38
+
39
+
40
def config_serialization(quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
    """
    Serialize a trainable quantizer config into a plain dictionary.

    Args:
        quantization_config: A TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig for serialization

    Returns: A config dictionary of quantizer config

    """
    # Replace Enum attributes with their raw values so the dict is serialization-friendly.
    serialized = {attr_name: transform_enum(attr_value)
                  for attr_name, attr_value in quantization_config.__dict__.items()}
    # Record which config type was serialized so deserialization can rebuild the right class.
    serialized[C.IS_WEIGHTS] = isinstance(quantization_config, TrainableQuantizerWeightsConfig)
    serialized[C.IS_ACTIVATIONS] = isinstance(quantization_config, TrainableQuantizerActivationConfig)
    return serialized
53
+
54
+
55
def config_deserialization(in_config: dict) -> Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
    """
    Turn a config dictionary back into a trainable quantizer config (inverse of config_serialization).

    Args:
        in_config: A config dictionary of trainable quantizer config.

    Returns: Trainable quantizer configuration object - TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig

    Raises:
        NotImplementedError: If the dictionary marks neither a weights nor an activations config.
    """
    # Deep-copy so the caller's dictionary is never mutated.
    in_config = copy.deepcopy(in_config)
    if in_config[C.IS_WEIGHTS]:
        # NOTE(review): bit-width candidates are not restored here, so a deserialized config
        # always has weights_quantization_candidates=None - confirm this is intended.
        return TrainableQuantizerWeightsConfig(weights_quantization_method=QuantizationMethod(in_config[C.WEIGHTS_QUANTIZATION_METHOD]),
                                               weights_n_bits=in_config[C.WEIGHTS_N_BITS],
                                               weights_quantization_params=in_config[C.WEIGHTS_QUANTIZATION_PARAMS],
                                               enable_weights_quantization=in_config[C.ENABLE_WEIGHTS_QUANTIZATION],
                                               weights_channels_axis=in_config[C.WEIGHTS_CHANNELS_AXIS],
                                               weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD],
                                               min_threshold=in_config[C.MIN_THRESHOLD])
    elif in_config[C.IS_ACTIVATIONS]:
        return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod(in_config[C.ACTIVATION_QUANTIZATION_METHOD]),
                                                  activation_n_bits=in_config[C.ACTIVATION_N_BITS],
                                                  activation_quantization_params=in_config[C.ACTIVATION_QUANTIZATION_PARAMS],
                                                  enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION],
                                                  min_threshold=in_config[C.MIN_THRESHOLD])
    else:
        # Bug fix: `raise NotImplemented` raises a TypeError because NotImplemented is a
        # sentinel value, not an exception class; NotImplementedError is the correct type.
        raise NotImplementedError(f'Unsupported trainable quantizer config: {in_config}')  # pragma: no cover
@@ -0,0 +1,48 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+ import tensorflow as tf
18
+
19
+
20
def int_quantization_with_threshold(data: tf.Tensor,
                                    n_bits: int,
                                    signed: bool,
                                    threshold: np.ndarray,
                                    eps: float) -> tf.Tensor:
    """
    Divides data by threshold and quantizes it to integers in the quantization range
    (depends on signed value).

    Args:
        data: tensor data.
        n_bits: number of bits that determines the quantization range.
        signed: Whether the quantization is signed or not.
        threshold: threshold for quantization.
        eps: Small value for numerical stability in division.

    Returns:
        Uniform Quantized tensor.

    """
    # Integer range limits for the requested bit-width.
    if signed:
        clip_min, clip_max = -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1
    else:
        clip_min, clip_max = 0, 2 ** n_bits - 1

    # Signed quantization reserves one bit for the sign, halving the scale.
    scale = 2 ** (n_bits - int(signed))
    return tf.clip_by_value((data / (threshold + eps)) * scale,
                            clip_value_max=clip_max, clip_value_min=clip_min)
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,66 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Union, List
16
+
17
+ from model_compression_toolkit.core.common.logger import Logger
18
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
22
+ TrainableQuantizerActivationConfig
23
+
24
+
25
# PyTorch is an optional dependency: the real Pytorch quantizer base is defined only
# when torch is installed; otherwise a stub that fails fast at construction is defined.
if FOUND_TORCH:

    import torch

    class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            This class is a base Pytorch quantizer which validates the provided quantization config and defines an
            abstract function which any quantizer needs to implement.

            Args:
                quantization_config: quantizer config class contains all the information about the quantizer configuration.
            """
            super().__init__(quantization_config)


        def get_trainable_variables(self, group: VariableGroup) -> List[torch.Tensor]:
            """
            Get trainable parameters with specific group from quantizer

            Args:
                group: Enum of variable group

            Returns:
                List of trainable variables
            """
            quantizer_trainable = []
            for name, parameter_dict in self.quantizer_parameters.items():
                quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
                # Torch tensors expose `requires_grad`; collect only those in the requested group.
                if quantizer_parameter.requires_grad and parameter_group == group:
                    quantizer_trainable.append(quantizer_parameter)
            return quantizer_trainable

else:
    class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
        # Stub used when PyTorch is missing: any attempt to construct a quantizer
        # raises via Logger.critical with an installation hint.
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            super().__init__(quantization_config)
            Logger.critical('Installing Pytorch is mandatory '
                            'when using BasePytorchTrainableQuantizer. '
                            'Could not find torch package.')  # pragma: no cover