mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -1,66 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import Any
16
-
17
- from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
18
- NoOpQuantizeConfig
19
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
20
-
21
- from model_compression_toolkit.core.common import BaseNode
22
- from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
23
- from model_compression_toolkit.exporter.model_wrapper.keras.builder.quantizer_to_node import \
24
- get_weights_quantizer_for_node, get_activations_quantizer_for_node
25
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.activation_quantize_config import \
26
- ActivationQuantizeConfig
27
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_activation_quantize_config \
28
- import \
29
- WeightsActivationQuantizeConfig
30
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_quantize_config import \
31
- WeightsQuantizeConfig
32
-
33
- SUPPORTED_QUANTIZATION_CONFIG = [WeightsQuantizeConfig,
34
- ActivationQuantizeConfig,
35
- NoOpQuantizeConfig,
36
- WeightsActivationQuantizeConfig]
37
-
38
-
39
- def get_quantization_config(node: BaseNode) -> QuantizeConfig:
40
- """
41
- Create a QuantizeConfig to wrap a layer for its corresponding node.
42
-
43
- Args:
44
- node: Node to create a QuantizeConfig for.
45
-
46
- Returns:
47
- QuantizeConfig to use for wrapping the layer from the passed node.
48
- """
49
-
50
- if node.is_weights_quantization_enabled() and not node.is_activation_quantization_enabled():
51
- weight_attrs = DEFAULT_KERAS_INFO.get_kernel_op_attributes(node.type)
52
- return WeightsQuantizeConfig(weight_attrs=weight_attrs,
53
- w_quantizer=get_weights_quantizer_for_node(node,
54
- weight_attrs))
55
-
56
- elif not node.is_weights_quantization_enabled() and node.is_activation_quantization_enabled():
57
- return ActivationQuantizeConfig(activation_quantizer=get_activations_quantizer_for_node(node))
58
-
59
- elif not node.is_weights_quantization_enabled() and not node.is_activation_quantization_enabled():
60
- return NoOpQuantizeConfig()
61
-
62
- weight_attrs = DEFAULT_KERAS_INFO.get_kernel_op_attributes(node.type)
63
- return WeightsActivationQuantizeConfig(activation_quantizer=get_activations_quantizer_for_node(node),
64
- w_quantizer=get_weights_quantizer_for_node(node,
65
- weight_attrs),
66
- weight_attrs=DEFAULT_KERAS_INFO.get_kernel_op_attributes(node.type))
@@ -1,134 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import numpy as np
16
- from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
17
- from typing import List
18
-
19
- from model_compression_toolkit.core.common import BaseNode, Logger
20
- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
21
- from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import calculate_delta
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
23
- from model_compression_toolkit.exporter.model_wrapper.keras.quantizers.fq_quantizer import FakeQuantQuantizer
24
- from model_compression_toolkit.exporter.model_wrapper.keras.quantizers.weights_uniform_quantizer import \
25
- WeightsUniformQuantizer
26
-
27
# Supporting other quantizer types in the future
# Quantization methods the exporter can currently translate into a
# WeightsUniformQuantizer for weights.
SUPPORTED_WEIGHT_QUANTIZER_TYPES = [QuantizationMethod.POWER_OF_TWO,
                                    QuantizationMethod.SYMMETRIC,
                                    QuantizationMethod.UNIFORM]

# Quantization methods the exporter can currently translate into a
# FakeQuantQuantizer for activations.
SUPPORTED_ACTIVATION_QUANTIZER_TYPES = [QuantizationMethod.POWER_OF_TWO,
                                        QuantizationMethod.SYMMETRIC,
                                        QuantizationMethod.UNIFORM]
36
-
37
def get_weights_quantizer_for_node(node: BaseNode, weights_attr: List[str]) -> Quantizer:
    """
    Build the weights quantizer matching a node's final weights quantization configuration.

    Args:
        node: Node to create a weight quantizer for.
        weights_attr: Attributes of the layer to quantize its weights.

    Returns:
        Quantizer for the node's weights.

    """
    if node.final_weights_quantization_cfg is None:
        Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration')

    qc = node.final_weights_quantization_cfg
    method = qc.weights_quantization_method

    if method not in SUPPORTED_WEIGHT_QUANTIZER_TYPES:
        Logger.error(f'Fully quantized models are now supported for {SUPPORTED_WEIGHT_QUANTIZER_TYPES} quantization methods, but node has {method} quantization method')

    thresholds = qc.weights_quantization_params.get(THRESHOLD)

    # Translate the node's quantization parameters into a [min, max] range.
    if method in (QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC):
        if method == QuantizationMethod.POWER_OF_TWO:
            # Each per-channel threshold must be an exact power of two.
            pot = np.all([np.log2(t) == int(np.log2(t)) for t in thresholds.flatten()])
            if not pot:
                Logger.error(f'Expected threshold to be power of 2 but is {thresholds}')

        min_range = -thresholds
        max_range = thresholds - calculate_delta(thresholds,
                                                 n_bits=qc.weights_n_bits,
                                                 signed=True)

    elif method in (QuantizationMethod.UNIFORM,):
        min_range = qc.weights_quantization_params.get(RANGE_MIN)
        max_range = qc.weights_quantization_params.get(RANGE_MAX)

    else:
        Logger.error(f'For now fully quantized models support only {SUPPORTED_WEIGHT_QUANTIZER_TYPES} for weights quantization, but found {method}')

    if len(weights_attr) > 1:
        Logger.error(f'Currently, we support only one quantized weight per layer, but received {len(weights_attr)} attributes to quantize')

    return WeightsUniformQuantizer(nbits=qc.weights_n_bits,
                                   min_range=min_range,
                                   max_range=max_range,
                                   weight=node.get_weights_by_keys(weights_attr[0]),
                                   quantization_method=method)
87
-
88
-
89
def get_activations_quantizer_for_node(node: BaseNode) -> Quantizer:
    """
    Get activation quantizer for a node.

    Args:
        node: Node to create an activation quantizer for.

    Returns:
        Quantizer for the node's activations.

    """

    if node.final_activation_quantization_cfg is None:
        Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration')

    node_act_qc = node.final_activation_quantization_cfg
    activation_quantization_method = node_act_qc.activation_quantization_method

    if activation_quantization_method not in SUPPORTED_ACTIVATION_QUANTIZER_TYPES:
        Logger.error(
            f'Fully quantized models are now supported for {SUPPORTED_ACTIVATION_QUANTIZER_TYPES} quantization methods, '
            f'but node has {activation_quantization_method} quantization method')

    activation_thresholds = node_act_qc.activation_quantization_params.get(THRESHOLD)
    # Hoist the repeated params-dict lookup: SIGNED is needed both for the
    # range sign and for the delta computation below.
    is_signed = node_act_qc.activation_quantization_params.get(SIGNED)

    if activation_quantization_method in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC]:
        if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
            is_threshold_pot = np.all([int(np.log2(x)) == np.log2(x) for x in activation_thresholds.flatten()])
            if not is_threshold_pot:
                # Use the already-bound thresholds instead of re-reading the params dict.
                Logger.error(f'Expected threshold to be power of 2 but is {activation_thresholds}')

        # Unsigned activations are quantized in [0, T); signed ones in [-T, T).
        min_range = -activation_thresholds if is_signed else 0

        max_range = activation_thresholds - calculate_delta(
            activation_thresholds,
            n_bits=node_act_qc.activation_n_bits,
            signed=is_signed)
    else:
        Logger.error(f'For now fully quantized models support only {SUPPORTED_ACTIVATION_QUANTIZER_TYPES} for activation quantization, but found {activation_quantization_method}')

    return FakeQuantQuantizer(nbits=node_act_qc.activation_n_bits,
                              min_range=min_range,
                              max_range=max_range,
                              quantization_method=activation_quantization_method)
@@ -1,81 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Dict, Tuple, Any
17
-
18
- from keras.engine.base_layer import Layer
19
- from keras.layers import TFOpLambda
20
- from tensorflow import TensorShape
21
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
22
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
23
-
24
-
25
class ExtendedQuantizeWrapper(QuantizeWrapper):

    """Quantizes the weights and activations of the keras layer it wraps, according
    to the quantization config that is passed. This class was created to deal with TFOpLambda that can
    not use TF QuantizeWrapper since it does not implement compute_output_shape.
    Notice that reused layers do not have a compute_output_shape method, thus the added method
    is irrelevant for wrapping them."""

    def __init__(self,
                 layer:Layer,
                 quantize_config:QuantizeConfig,
                 output_shape:Tuple[Any]=None,
                 **kwargs:Dict[str, Any]):
        """
        Create a wrapper for a keras layer.

        Args:
            layer: The keras layer to be quantized.
            quantize_config: `QuantizeConfig` to quantize the layer.
            output_shape: The output shape of the layer.
            **kwargs: Keyword arguments to build the base class QuantizeWrapper.
        """

        # TFOpLambda does not implement the method compute_output_shape which is mandatory for cloning the model
        # and use TF transformations. For this reason, we add the output_shape in the layer configuration and
        # add an implementation for compute_output_shape.
        # _output_shape is assigned before the monkey-patch and before super().__init__,
        # so the bound _compute_output_shape can resolve it as soon as it is installed.
        # NOTE(review): the patch mutates the passed-in layer instance itself — callers
        # holding a reference to `layer` see the added method too; confirm this is intended.
        self._output_shape = output_shape
        if isinstance(layer, TFOpLambda):
            layer.compute_output_shape = self._compute_output_shape

        super(ExtendedQuantizeWrapper, self).__init__(layer=layer,
                                                      quantize_config=quantize_config,
                                                      **kwargs)

    def _compute_output_shape(self, input_shape: TensorShape) -> Tuple[Any]:
        """
        Internal method that returns the output shape of the layer.

        Note: `input_shape` is ignored — the shape captured at construction
        time is returned as-is.

        Args:
            input_shape: Input shape the layer is expecting to have.

        Returns:
            The layer's output shape.
        """

        return self._output_shape

    def get_config(self) -> Dict[str, Any]:
        """

        Returns: The layer configuration with the output shape, so it can be deserialized.

        """
        cfg = super(ExtendedQuantizeWrapper, self).get_config()
        # Persist the captured output shape so deserialization can rebuild the wrapper
        # with the same shape information.
        cfg.update({'output_shape': self._output_shape})
        return cfg
@@ -1,81 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import List, Tuple, Any, Dict
17
-
18
- import tensorflow as tf
19
- from tensorflow import Tensor
20
- from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
21
- from packaging import version
22
-
23
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
24
- if version.parse(tf.__version__) < version.parse("2.6"):
25
- from tensorflow.python.keras.layers import Layer
26
- else:
27
- from keras.engine.base_layer import Layer
28
- from tensorflow.python.training.tracking.data_structures import ListWrapper
29
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
30
-
31
-
32
class ActivationQuantizeConfig(QuantizeConfig):
    """
    QuantizeConfig that quantizes only the activations of the wrapped layer.
    """

    def __init__(self,
                 activation_quantizer: Quantizer):
        """
        Args:
            activation_quantizer: Quantizer for quantization the layer's activations.
        """
        self.activation_quantizer = activation_quantizer

    def get_config(self) -> Dict[str, Any]:
        """
        Return the serializable configuration of this quantize config.
        """
        return dict(activation_quantizer=self.activation_quantizer)

    def get_weights_and_quantizers(self, layer: Layer) -> List[Tuple[Tensor, Any]]:
        # This config does not quantize any weights.
        return []

    def get_activations_and_quantizers(self, layer: Layer) -> list:
        # For configurable activations we use get_output_quantizers,
        # Therefore, we do not need to implement this method.
        return []

    def set_quantize_weights(self, layer: Layer, quantize_weights: List[Tensor]):
        # Nothing to set: no weights are quantized by this config.
        pass

    def set_quantize_activations(self, layer, quantize_activations: ListWrapper):
        # Nothing to set: outputs are quantized through get_output_quantizers.
        pass

    def get_output_quantizers(self, layer: Layer) -> List[Quantizer]:
        """
        Quantize layer's outputs.

        Args:
            layer: Layer to quantize its activations.

        Returns: List of activation quantizers.

        """
        return [self.activation_quantizer]
@@ -1,128 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import List, Tuple, Any, Dict
17
-
18
- import tensorflow as tf
19
- from tensorflow import Tensor
20
- from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
21
-
22
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
23
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.activation_quantize_config import \
24
- ActivationQuantizeConfig
25
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_quantize_config import \
26
- WeightsQuantizeConfig
27
- from packaging import version
28
-
29
- if version.parse(tf.__version__) < version.parse("2.6"):
30
- from tensorflow.python.keras.layers import Layer
31
- else:
32
- from keras.engine.base_layer import Layer
33
- from tensorflow.python.training.tracking.data_structures import ListWrapper
34
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
35
-
36
-
37
class WeightsActivationQuantizeConfig(QuantizeConfig):
    """
    QuantizeConfig that quantizes both the weights and the activations of the
    wrapped layer by delegating to a WeightsQuantizeConfig and an
    ActivationQuantizeConfig.
    """

    def __init__(self,
                 activation_quantizer: Quantizer,
                 w_quantizer: Quantizer,
                 weight_attrs: List[str] = None):
        """
        Args:
            activation_quantizer: Quantizer for activations.
            w_quantizer: Quantizer for weights.
            weight_attrs: Weights attributes to quantize.
        """
        # Compose the two single-purpose configs; all protocol methods below
        # simply forward to the appropriate delegate.
        self.act_config = ActivationQuantizeConfig(activation_quantizer=activation_quantizer)
        self.weights_config = WeightsQuantizeConfig(w_quantizer=w_quantizer,
                                                    weight_attrs=weight_attrs)

    def get_config(self) -> Dict[str,Any]:
        """
        Return the serializable configuration (the union of both delegates').
        """
        return {"activation_quantizer": self.act_config.activation_quantizer,
                "w_quantizer": self.weights_config.w_quantizer,
                "weight_attrs": self.weights_config.weight_attrs}

    def get_weights_and_quantizers(self, layer: Layer) -> List[Tuple[Tensor, Any]]:
        """Forward to the weights delegate: weights and their quantizers."""
        return self.weights_config.get_weights_and_quantizers(layer)

    def get_activations_and_quantizers(self, layer: Layer) -> list:
        """Forward to the activation delegate: activations and their quantizers."""
        return self.act_config.get_activations_and_quantizers(layer)

    def set_quantize_weights(self, layer: Layer, quantize_weights: List[Tensor]):
        """Forward to the weights delegate: install quantized weights on the layer."""
        self.weights_config.set_quantize_weights(layer, quantize_weights)

    def set_quantize_activations(self, layer, quantize_activations: ListWrapper):
        """Forward to the activation delegate: install quantized activations."""
        self.act_config.set_quantize_activations(layer, quantize_activations)

    def get_output_quantizers(self, layer: Layer) -> List[Quantizer]:
        """Forward to the activation delegate: quantizers for the layer's outputs."""
        return self.act_config.get_output_quantizers(layer)
@@ -1,107 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import List, Tuple, Any, Dict
17
-
18
- import tensorflow as tf
19
- from tensorflow import Tensor
20
- from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
21
- from packaging import version
22
-
23
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
24
- if version.parse(tf.__version__) < version.parse("2.6"):
25
- from tensorflow.python.keras.layers import Layer
26
- else:
27
- from keras.engine.base_layer import Layer
28
- from tensorflow.python.training.tracking.data_structures import ListWrapper
29
- from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
30
-
31
-
32
class WeightsQuantizeConfig(QuantizeConfig):
    """
    QuantizeConfig to quantize a layer's weights.
    """

    def __init__(self,
                 w_quantizer: Quantizer,
                 weight_attrs: List[str] = None):
        """

        Args:
            w_quantizer: Quantizer for weights.
            weight_attrs: Weights attributes to quantize.
        """

        self.weight_attrs = weight_attrs
        self.w_quantizer = w_quantizer

    def get_config(self) -> Dict[str, Any]:
        """

        Returns: Configuration of WeightsQuantizeConfig

        """
        return {'w_quantizer': self.w_quantizer,
                'weight_attrs': self.weight_attrs}

    def get_weights_and_quantizers(self, layer: Layer) -> List[Tuple[Tensor, Any]]:
        """
        Get the layer's weights to quantize and quantizers.

        Args:
            layer: Layer wrapped with this WeightsQuantizeConfig

        Returns:
            List of weights and quantizers to quantize these weights.
        """
        # Iterate the attribute names directly instead of indexing via range(len(...)).
        return [(getattr(layer, attr), self.w_quantizer) for attr in self.weight_attrs]

    def get_activations_and_quantizers(self, layer: Layer) -> list:
        # For configurable activations we use get_output_quantizers,
        # Therefore, we do not need to implement this method.
        return []

    def set_quantize_weights(self, layer: Layer, quantize_weights: List[Tensor]):
        """
        Set layer's weights with quantized weights.

        Args:
            layer: Layer wrapped with this WeightsQuantizeConfig
            quantize_weights: Quantized weights to set to the layer

        Returns:
            None

        Raises:
            ValueError: If the number of quantized weights does not match the
                configured attributes, or a weight's shape differs from the
                layer's existing weight shape.
        """
        if len(self.weight_attrs) != len(quantize_weights):
            raise ValueError(
                '`set_quantize_weights` called on layer {} with {} '
                'weight parameters, but layer expects {} values.'.format(
                    layer.name, len(quantize_weights), len(self.weight_attrs)))

        for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
            current_weight = getattr(layer, weight_attr)
            if current_weight.shape != weight.shape:
                # Fixed: the original adjacent string literals concatenated to
                # "...incompatible withprovided weight shape..." (missing space).
                raise ValueError('Existing layer weight shape {} is incompatible with '
                                 'provided weight shape {}'.format(
                                     current_weight.shape, weight.shape))

            setattr(layer, weight_attr, weight)

    def set_quantize_activations(self, layer, quantize_activations: ListWrapper):
        # Nothing to set: this config does not quantize activations.
        pass

    def get_output_quantizers(self, layer: Layer) -> List[Quantizer]:
        # No output (activation) quantizers for a weights-only config.
        return []