mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py
@@ -0,0 +1,45 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Callable
+
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
+from model_compression_toolkit.gptq.keras.quantizer.soft_rounding.soft_quantizer_reg import \
+    SoftQuantizerRegularization
+
+
+def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
+    """
+    Returns a function that computes the regularization term for GPTQ training based on the given
+    rounding type in the GPTQ configuration.
+
+    Args:
+        gptq_config: A GPTQ configuration.
+        representative_data_gen: Dataset used for the GPTQ training.
+
+    Returns: A function for computing the regularization. If there is no regularization function defined for the given
+    rounding type, then it returns a function that just returns 0.
+
+    """
+    if gptq_config.rounding_type == RoundingType.SoftQuantizer:
+        # dry run on the representative dataset to count number of batches
+        num_batches = 0
+        for _ in representative_data_gen():
+            num_batches += 1
+
+        n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
+            not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
+        return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+    else:
+        return lambda m, e_reg: 0
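
The factory above sizes the soft-quantizer regularization schedule by dry-running the representative dataset once. A minimal standalone sketch of that bookkeeping follows (plain Python/NumPy, no MCT dependency; the toy generator and the epoch count are hypothetical, not part of the package):

import numpy as np

def representative_data_gen():
    # Toy generator yielding 10 batches of random "images" (hypothetical stand-in).
    for _ in range(10):
        yield [np.random.randn(4, 224, 224, 3).astype(np.float32)]

# Dry run over the generator to count batches, as get_regularization() does
# when rounding_type is SoftQuantizer.
num_batches = sum(1 for _ in representative_data_gen())
n_epochs = 50  # in MCT this value would come from the (v2) GPTQ config
total_gradient_steps = num_batches * n_epochs
print(num_batches, total_gradient_steps)  # -> 10 500
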
model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -0,0 +1,112 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List
+
+import tensorflow as tf
+from keras import Model
+
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+
+
+class LinearTempDecay:
+    """
+    Annealing process for the soft quantizer regularization temperature term.
+    """
+
+    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
+        """
+        Initializes a LinearTempDecay object.
+
+        Args:
+            t_max: maximal time step.
+            rel_start_decay: Decay step size at the beginning of the process.
+            start_b: Starting value of the regularization term.
+            end_b: Target value of the regularization term.
+        """
+
+        self.t_max = t_max
+        self.start_decay = rel_start_decay * t_max
+        self.start_b = start_b
+        self.end_b = end_b
+
+    def __call__(self, t: int) -> float:
+        """
+        Cosine annealing scheduler for soft quantizer regularization temperature term.
+
+        Args:
+            t: The current time step.
+
+        Returns: Scheduled temperature.
+        """
+
+        is_before_start_decay = tf.cast(t < self.start_decay, tf.float32)
+
+        rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
+
+        return self.start_b * is_before_start_decay + \
+            (1 - is_before_start_decay) * \
+            (self.end_b + (self.start_b - self.end_b) * tf.math.maximum(0.0, (1 - rel_t)))
+
+
+class SoftQuantizerRegularization:
+    """
+    A class to handle the computation of soft quantizer regularization for GPTQ training.
+    """
+
+    def __init__(self, total_gradient_steps: int):
+        """
+        Initializes the regularization computation object with a LinearDecay object.
+
+        Args:
+            total_gradient_steps: The number of gradient steps during optimization.
+        """
+        # Initializing the temperature decay according to the number of expected gradient steps
+        self.linear_decay = LinearTempDecay(total_gradient_steps)
+
+        self.count_iter = 0
+
+
+    def __call__(self, model: Model, entropy_reg: float):
+        """
+        Returns the soft quantizer regularization value for SoftRounding.
+
+        Args:
+            model: A model to be quantized with SoftRounding.
+            entropy_reg: Entropy value to scale the quantizer regularization.
+
+        Returns: Regularization value.
+        """
+
+        soft_reg_aux: List[tf.Tensor] = []
+        for layer in model.layers:
+            if isinstance(layer, KerasQuantizationWrapper):
+                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                      fw_info=DEFAULT_KERAS_INFO)
+
+                st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
+                b = self.linear_decay(self.count_iter)
+
+                soft_reg_aux.append(tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b)))
+
+        reg = 0
+
+        for sq in soft_reg_aux:
+            reg += sq
+
+        self.count_iter += 1
+
+        return entropy_reg * reg
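
For reference, the temperature schedule that LinearTempDecay implements can be reproduced without TensorFlow: the temperature stays at start_b for the first rel_start_decay fraction of steps and then ramps linearly down to end_b, and each wrapped layer contributes sum(1 - |2*st - 1|^b) to the regularization. The sketch below only mirrors the formula in the diff and is not part of the package:

def linear_temp_decay(t, t_max, rel_start_decay=0.2, start_b=20.0, end_b=2.0):
    # Plain-Python mirror of LinearTempDecay.__call__ above.
    start_decay = rel_start_decay * t_max
    if t < start_decay:
        return start_b  # temperature held at start_b for the first 20% of steps
    rel_t = (t - start_decay) / (t_max - start_decay)
    return end_b + (start_b - end_b) * max(0.0, 1.0 - rel_t)

t_max = 1000
print([linear_temp_decay(t, t_max) for t in (0, 200, 600, 1000)])  # [20.0, 20.0, 11.0, 2.0]
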
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py
@@ -0,0 +1,256 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+import numpy as np
+
+from model_compression_toolkit.gptq import RoundingType
+from model_compression_toolkit import quantizers_infrastructure as qi
+from model_compression_toolkit.core.common import max_power_of_two
+from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
+    SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
+from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
+from typing import Dict, Any
+from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
+from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
+from model_compression_toolkit.gptq.keras.quantizer.quant_utils import power_of_two_max, clip, calculate_delta
+from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+    get_threshold_reshape_shape
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+
+
+def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
+                                      auxvar_tensor: tf.Variable,
+                                      threshold_tensor: tf.Tensor,
+                                      num_bits: int,
+                                      signed: bool,
+                                      power_of_two: bool) -> tf.Tensor:
+    """
+    Quantize a tensor symmetrically for GPTQ quantizers.
+
+    Args:
+        input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
+        auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to gptq training.
+        threshold_tensor: Tensor with values to compute the threshold.
+        num_bits: Num of bits to use.
+        signed: Signedness of the quantization range.
+        power_of_two: Whether the threshold should be constrained or not.
+
+    Returns:
+        A quantized tensor.
+    """
+
+    if power_of_two:
+        threshold_tensor = power_of_two_max(threshold_tensor)
+    delta = calculate_delta(threshold_tensor, num_bits, signed)
+    input_tensor = tf.stop_gradient(input_tensor)
+    input_tensor_int = tf.floor(input_tensor / delta)
+    tensor_q = input_tensor_int + auxvar_tensor
+    min_int = -int(signed) * (2 ** (num_bits - int(signed)))
+    max_int = (2 ** (num_bits - int(signed))) - 1
+    return delta * clip(tensor_q, max_val=max_int, min_val=min_int)
+
+
+@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                quantizer_type=RoundingType.SoftQuantizer)
+class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
+    """
+    Trainable symmetric quantizer to optimize the rounding of the quantized values using a soft quantization method.
+    """
+
+    def __init__(self,
+                 quantization_config: TrainableQuantizerWeightsConfig,
+                 quantization_parameter_learning: bool = False):
+        """
+        Initialize a SymmetricSoftRoundingGPTQ object with parameters to use
+        for the quantization.
+
+        Args:
+            quantization_config: Trainable weights quantizer config.
+            quantization_parameter_learning: Whether to train the quantization threshold.
+        """
+        super().__init__(quantization_config)
+        self.num_bits = quantization_config.weights_n_bits
+        self.per_channel = quantization_config.weights_per_channel_threshold
+
+        threshold_values = quantization_config.weights_quantization_params[THRESHOLD]
+        self.threshold_shape = np.asarray(threshold_values).shape
+        self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_channel else np.asarray(
+            threshold_values)
+
+        self.quantization_axis = quantization_config.weights_channels_axis
+        self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+        self.quantization_parameter_learning = quantization_parameter_learning
+        self.num_channels = len(self.threshold_values) if self.per_channel else 1
+
+        # gamma and zeta are stretch parameters for computing the rectified sigmoind function.
+        # See: https://arxiv.org/pdf/2004.10568.pdf
+        self.gamma = SOFT_ROUNDING_GAMMA
+        self.zeta = SOFT_ROUNDING_ZETA
+
+        self.quantizer_parameters = {}
+
+    def initialize_quantization(self,
+                                tensor_shape: Any,
+                                name: str,
+                                layer: Any):
+        """
+        Add quantizer parameters to the quantizer parameters dictionary
+
+        Args:
+            tensor_shape: tensor shape of the quantized tensor.
+            name: Tensor name.
+            layer: Layer to quantize.
+        """
+
+        if self.per_channel:
+            reshape_shape = get_threshold_reshape_shape(tensor_shape,
+                                                        quant_axis=self.quantization_axis,
+                                                        quant_axis_dim=self.num_channels)
+        else:
+            reshape_shape = [self.num_channels]
+
+        ptq_threshold_tensor = layer.add_weight(
+            f"{name}_{PTQ_THRESHOLD}",
+            shape=reshape_shape,
+            initializer=tf.keras.initializers.Constant(1.0),
+            trainable=False)
+        ptq_threshold_tensor.assign(self.threshold_values.reshape(reshape_shape))
+
+        w = getattr(layer.layer, name)
+        auxvar_tensor = layer.add_weight(
+            f"{name}_{AUXVAR}",
+            shape=list(w.shape),
+            initializer=tf.keras.initializers.Constant(0.0),
+            trainable=True)
+
+        delta = qutils.calculate_delta(ptq_threshold_tensor, self.num_bits, signed=True)
+        w_floor = tf.floor(w / delta)
+        rest = (w / delta) - w_floor  # rest of rounding [0, 1)
+        # Note that (rest - self.gamma) can't be zero since rest is positive and gamma is negative, so the division
+        # is safe
+        alpha = -qutils.safe_log((self.zeta - self.gamma) / (rest - self.gamma) - 1, 1e-16)  # => sigmoid(alpha) = rest
+
+        auxvar_tensor.assign(alpha)
+
+        # Add quantization variables
+        self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
+        self.add_quantizer_variable(PTQ_THRESHOLD, ptq_threshold_tensor, VariableGroup.QPARAMS)
+
+        if self.quantization_parameter_learning and not self.power_of_two:
+            scale = layer.add_weight(
+                f"{name}_{SCALE_PTQ}",
+                shape=self.num_channels,
+                initializer=tf.keras.initializers.Constant(1.0),
+                trainable=True)
+            self.add_quantizer_variable(SCALE_PTQ, scale, VariableGroup.QPARAMS)
+
+    def get_soft_targets(self) -> tf.Tensor:
+        """
+        Computes the rectified sigmoid function for the quantization target parameters.
+
+        Returns:
+            A tensor with the soft rounding targets values.
+
+        """
+        return qutils.clip(
+            tf.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma, 1, 0)
+
+    def __call__(self,
+                 inputs: tf.Tensor,
+                 training: bool):
+        """
+        Quantize a tensor.
+
+        Args:
+            inputs: Input tensor to quantize.
+            training: Whether the graph is in training mode.
+
+        Returns:
+            The quantized tensor.
+        """
+
+        ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
+
+        #####################################################
+        # Soft Rounding
+        #####################################################
+        aux_var = self.get_soft_targets()
+        if not training:
+            aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
+
+        if self.per_channel:
+            reshape_shape = get_threshold_reshape_shape(inputs.shape,
+                                                        quant_axis=self.quantization_axis,
+                                                        quant_axis_dim=-1)
+
+            ##########################################################
+            # Calculate soft rounding targets and optimized threshold
+            ##########################################################
+            ptq_threshold_tensor_hat = tf.reshape(ptq_threshold_tensor, reshape_shape)
+
+            #####################################################
+            # Quantized Input
+            #####################################################
+            q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
+                                                         auxvar_tensor=aux_var,
+                                                         threshold_tensor=ptq_threshold_tensor_hat,
+                                                         num_bits=self.num_bits,
+                                                         signed=True,
+                                                         power_of_two=self.power_of_two)
+
+            if self.quantization_parameter_learning and not self.power_of_two:
+                scale = tf.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
+                q_tensor *= scale
+
+        else:
+            q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
+                                                         auxvar_tensor=aux_var,
+                                                         threshold_tensor=ptq_threshold_tensor.value(),
+                                                         num_bits=self.num_bits,
+                                                         signed=True,
+                                                         power_of_two=self.power_of_two)
+
+            if self.quantization_parameter_learning and not self.power_of_two:
+                scale = self.get_quantizer_variable(SCALE_PTQ)
+                q_tensor *= scale
+
+        return q_tensor
+
+    def get_quant_config(self) -> Dict[str, np.ndarray]:
+        """
+        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
+
+        Returns:
+            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+            Keys must match NodeQuantizationConfig attributes
+        """
+
+        if self.power_of_two:
+            old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
+            old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
+
+        else:
+            old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
+            if self.quantization_parameter_learning:
+                scale = tf.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
+                old_threshold = old_threshold * scale
+            old_threshold = old_threshold.numpy()
+        old_threshold = old_threshold.reshape(self.threshold_shape)
+        return {THRESHOLD: old_threshold}
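
The rounding arithmetic in soft_rounding_symmetric_quantizer above can be followed in a few lines of NumPy. The sketch below assumes calculate_delta uses the usual signed-symmetric step size threshold / 2**(num_bits - 1); that helper is not shown in this diff, so treat this as an illustration rather than the package's exact implementation:

import numpy as np

def soft_round_symmetric(w, soft_targets, threshold, num_bits=8, signed=True):
    # Assumed step size: threshold / 2**(num_bits - 1) for signed symmetric quantization.
    delta = threshold / (2 ** (num_bits - int(signed)))
    w_int = np.floor(w / delta)              # hard lower grid point
    q = w_int + soft_targets                 # soft targets in [0, 1] replace round()
    min_int = -int(signed) * 2 ** (num_bits - int(signed))
    max_int = 2 ** (num_bits - int(signed)) - 1
    return delta * np.clip(q, min_int, max_int)

w = np.array([0.30, -0.12, 0.07])
targets = np.array([1.0, 0.0, 1.0])          # in eval mode the soft targets are hard-thresholded to {0, 1}
print(soft_round_symmetric(w, targets, threshold=1.0))  # ~[ 0.3046875 -0.125      0.0703125]
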