mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -1,99 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================\
15
-
16
- from typing import Dict, Any
17
- import numpy as np
18
- import tensorflow as tf
19
- from keras.engine.base_layer import Layer
20
- from tensorflow import TensorShape
21
- from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
22
-
23
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
24
-
25
-
26
- class FakeQuantQuantizer(Quantizer):
27
- """
28
- Quantizer using TensorFlow fake quant layer to quantize activations.
29
- """
30
-
31
- def __init__(self,
32
- nbits: int,
33
- min_range: np.ndarray,
34
- max_range: np.ndarray,
35
- quantization_method: QuantizationMethod):
36
- """
37
-
38
- Args:
39
- nbits: Number of bits to quantize.
40
- min_range: Min quantization range.
41
- max_range: Max quantization range.
42
- quantization_method: Quantization method that is used (POT, Uniform, etc.)
43
-
44
- """
45
- self.nbits = nbits
46
- self.min_range = tf.Variable(min_range,
47
- trainable=False,
48
- dtype=tf.float32)
49
- self.max_range = tf.Variable(max_range,
50
- trainable=False,
51
- dtype=tf.float32)
52
- self.quantization_method = quantization_method
53
-
54
-
55
- def get_config(self) -> Dict[str, Any]:
56
- """
57
-
58
- Returns: Configuration of this FakeQuantQuantizer
59
-
60
- """
61
- return {"nbits": self.nbits,
62
- "min_range": self.min_range.numpy(),
63
- "max_range": self.max_range.numpy(),
64
- "quantization_method": self.quantization_method,
65
- }
66
-
67
- def build(self, tensor_shape: TensorShape, name: str, layer: Layer) -> dict:
68
- """
69
- Add variables under layer's scope.
70
-
71
- Args:
72
- tensor_shape: Shape of tensor which needs to be quantized.
73
- name: Name of tensor.
74
- layer: Layer to add variables to.
75
-
76
- Returns:
77
- Dictionary with new layer's variables.
78
- """
79
- return {}
80
-
81
- def __call__(self, inputs, training, weights, **kwargs):
82
- """
83
- Apply quantization to the input tensor.
84
-
85
- Args:
86
- inputs: Input tensor to be quantized.
87
- training: Whether the graph is currently training.
88
- weights: Dictionary of weights the quantizer can use to quantize the tensor. This contains the weights created in the `build` function.
89
- **kwargs: Additional variables which may be passed to the quantizer.
90
-
91
- Returns:
92
- Quantized tensor.
93
- """
94
- with tf.name_scope('FakeQuant'):
95
- return tf.quantization.fake_quant_with_min_max_vars(inputs,
96
- min=self.min_range,
97
- max=self.max_range,
98
- num_bits=self.nbits)
99
-
@@ -1,105 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import numpy as np
17
- import tensorflow as tf
18
- from keras.engine.base_layer import Layer
19
- from tensorflow import TensorShape
20
- from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
21
-
22
- from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
23
- fix_range_to_include_zero
24
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
25
-
26
-
27
- class WeightsUniformQuantizer(Quantizer):
28
- """
29
- Quantizer for weights quantization.
30
- """
31
-
32
- def __init__(self,
33
- nbits: int,
34
- min_range: np.ndarray,
35
- max_range: np.ndarray,
36
- weight: np.ndarray,
37
- quantization_method: QuantizationMethod):
38
- """
39
-
40
- Args:
41
- nbits: Number of bits to quantize.
42
- min_range: Min quantization range.
43
- max_range: Max quantization range.
44
- weight: Tensor of weights to quantize.
45
- quantization_method: Quantization method that is used (POT, Uniform, etc.)
46
- """
47
-
48
- super().__init__()
49
-
50
- min_range, max_range = fix_range_to_include_zero(np.array(min_range),
51
- np.array(max_range),
52
- nbits)
53
- self.weight = uniform_quantize_tensor(weight,
54
- range_min=min_range,
55
- range_max=max_range,
56
- n_bits=nbits).astype("float32")
57
- self.nbits = nbits
58
- self.min_range = min_range
59
- self.max_range = max_range
60
- self.delta = (self.max_range - self.min_range) / (2 ** self.nbits - 1)
61
- self.quantization_method = quantization_method
62
-
63
- def __call__(self, inputs, training, weights, **kwargs):
64
- """
65
- Apply quantization to the input tensor.
66
-
67
- Args:
68
- inputs: Input tensor to be quantized.
69
- training: Whether the graph is currently training.
70
- weights: Dictionary of weights the quantizer can use to quantize the tensor. This contains the weights created in the `build` function.
71
- **kwargs: Additional variables which may be passed to the quantizer.
72
-
73
- Returns:
74
- Quantized tensor.
75
- """
76
- with tf.name_scope('WeightsUniformQuant'):
77
- return self.weight
78
-
79
- def get_config(self):
80
- """
81
-
82
- Returns: Configuration of this WeightsUniformQuantizer
83
-
84
- """
85
- cfg = {"nbits": self.nbits,
86
- "min_range": self.min_range,
87
- "max_range": self.max_range,
88
- "weight": self.weight,
89
- "quantization_method": self.quantization_method
90
- }
91
- return cfg
92
-
93
- def build(self, tensor_shape: TensorShape, name: str, layer: Layer) -> dict:
94
- """
95
- Add variables under layer's scope.
96
-
97
- Args:
98
- tensor_shape: Shape of tensor which needs to be quantized.
99
- name: Name of tensor.
100
- layer: Layer to add variables to.
101
-
102
- Returns:
103
- Dictionary with new layer's variables.
104
- """
105
- return {}
@@ -1,61 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from model_compression_toolkit.core.common import BaseNode
16
- from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
17
- WrapperQuantizeConfig
18
- from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
19
- get_weights_quantizer_for_node, \
20
- get_activations_quantizer_for_node
21
- from model_compression_toolkit.exporter.model_wrapper.pytorch.wrappers_quantize_configs.activation_quantize_config \
22
- import \
23
- ActivationQuantizeConfig
24
- from model_compression_toolkit.exporter.model_wrapper.pytorch.wrappers_quantize_configs \
25
- .no_quantization_quantize_config import \
26
- NoQuantizationQuantizeConfig
27
- from model_compression_toolkit.exporter.model_wrapper.pytorch.wrappers_quantize_configs \
28
- .weights_activation_quantize_config import \
29
- WeightsActivationQuantizeConfig
30
- from model_compression_toolkit.exporter.model_wrapper.pytorch.wrappers_quantize_configs.weights_quantize_config \
31
- import \
32
- WeightsQuantizeConfig
33
-
34
-
35
def get_quantization_config(node: BaseNode) -> WrapperQuantizeConfig:
    """
    Pick and build the WrapperQuantizeConfig matching a node's quantization flags.

    Mapping:
        weights and activations enabled -> WeightsActivationQuantizeConfig
        activations only                -> ActivationQuantizeConfig
        neither                         -> NoQuantizationQuantizeConfig
        weights only                    -> WeightsQuantizeConfig

    Args:
        node: Graph node whose layer is going to be wrapped.

    Returns:
        The WrapperQuantizeConfig (with its quantizers) used to wrap the node's layer.
    """
    weights_enabled = node.is_weights_quantization_enabled()
    activation_enabled = node.is_activation_quantization_enabled()

    if not weights_enabled:
        if activation_enabled:
            return ActivationQuantizeConfig(
                activation_quantizers=get_activations_quantizer_for_node(node))
        return NoQuantizationQuantizeConfig()

    if activation_enabled:
        # Keyword arguments are evaluated left to right: weight quantizers first,
        # then activation quantizers.
        return WeightsActivationQuantizeConfig(
            weight_quantizers=get_weights_quantizer_for_node(node),
            activation_quantizers=get_activations_quantizer_for_node(node))

    return WeightsQuantizeConfig(weight_quantizers=get_weights_quantizer_for_node(node))
@@ -1,59 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================\
15
- import numpy as np
16
- from typing import Dict, Any
17
-
18
- from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
- from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import uniform_quantization
21
-
22
-
23
class FakeQuantQuantizer:
    """
    Activation quantizer for PyTorch models.

    It applies the uniform fake-quantization function produced by
    `uniform_quantization` over a fixed [min_range, max_range] range.

    Note: `quantization_method` is stored for reference only. `__call__` always
    uses uniform quantization, whatever its value.
    """

    def __init__(self,
                 nbits: int,
                 min_range: np.ndarray,
                 max_range: np.ndarray,
                 quantization_method: QuantizationMethod):
        """

        Args:
            nbits: Number of bits to quantize.
            min_range: Min quantization range.
            max_range: Max quantization range.
            quantization_method: Quantization method that is used (POT, Uniform, etc.).
                Kept as metadata; not consulted when quantizing.

        """
        self.nbits = nbits
        self.min_range = min_range
        self.max_range = max_range
        self.quantization_method = quantization_method

    def __call__(self, inputs):
        """
        Apply quantization to the input tensor.

        Args:
            inputs: Input tensor to be quantized.

        Returns:
            Quantized tensor.
        """
        return uniform_quantization(self.nbits,
                                    {RANGE_MIN: self.min_range,
                                     RANGE_MAX: self.max_range})(inputs)
@@ -1,67 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import numpy as np
16
- import torch
17
-
18
- from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
-
21
-
22
class UniformWeightsQuantizer:
    """
    Weights quantizer for PyTorch modules.

    It applies uniform quantization with the stored range and bit-width.
    """

    def __init__(self,
                 num_bits: int,
                 min_range: np.ndarray,
                 max_range: np.ndarray,
                 quantization_method: QuantizationMethod,
                 per_channel: bool,
                 output_channels_axis: int
                 ):
        """
        Store the parameters the weight was quantized with.

        Args:
            num_bits: Bit-width used for quantization.
            min_range: Lower end of the quantization range.
            max_range: Upper end of the quantization range.
            quantization_method: Method the weight was quantized with.
            per_channel: Whether the weight was quantized per-channel.
            output_channels_axis: Channel axis (relevant for per-channel quantization).
        """
        super().__init__()
        self.num_bits = num_bits
        self.min_range = min_range
        self.max_range = max_range
        # The following three are kept as metadata; __call__ does not read them.
        self.quantization_method = quantization_method
        self.per_channel = per_channel
        self.output_channels_axis = output_channels_axis

    def __call__(self, float_weight: np.ndarray, *args, **kwargs) -> np.ndarray:
        """
        Quantize a float weight tensor with this quantizer's stored parameters.

        Args:
            float_weight: Float weights to quantize.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The tensor produced by `uniform_quantize_tensor`.
        """
        quantization_kwargs = dict(range_min=self.min_range,
                                   range_max=self.max_range,
                                   n_bits=self.num_bits)
        return uniform_quantize_tensor(tensor_data=float_weight, **quantization_kwargs)
@@ -1,52 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import List, Callable
16
-
17
- from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
18
- WrapperQuantizeConfig
19
-
20
-
21
class ActivationQuantizeConfig(WrapperQuantizeConfig):
    """
    Wrapper quantization config for a fully quantized model.

    Only the layer's activations are quantized; its weights are not.
    """

    def __init__(self, activation_quantizers: List[Callable]):
        """
        Args:
            activation_quantizers: Quantizers applied to the layer's activations.
        """
        super().__init__(is_weight_quantized=False,
                         is_activation_quantized=True)
        self._activation_quantizers = activation_quantizers

    def get_weight_quantizers(self) -> list:
        """
        Returns: A new empty list, because this config does not quantize weights.
        """
        return list()

    def get_activation_quantizers(self):
        """
        Returns: The activation quantizers this config was created with.
        """
        quantizers = self._activation_quantizers
        return quantizers
@@ -1,46 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
17
- WrapperQuantizeConfig
18
-
19
-
20
class NoQuantizationQuantizeConfig(WrapperQuantizeConfig):
    """
    Wrapper quantization config for a fully quantized model.

    Both the layer's activations and its weights stay un-quantized.
    """

    def __init__(self):
        """
        Mark the wrapped layer as having neither weight nor activation quantization.
        """
        super().__init__(is_weight_quantized=False,
                         is_activation_quantized=False)

    def get_weight_quantizers(self):
        """
        Returns: An empty list, since no weight quantization is applied.
        """
        return list()

    def get_activation_quantizers(self):
        """
        Returns: An empty list, since no activation quantization is applied.
        """
        return list()
44
-
45
-
46
-
@@ -1,54 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import Callable, List
16
-
17
- from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
18
- WrapperQuantizeConfig
19
-
20
-
21
class WeightsActivationQuantizeConfig(WrapperQuantizeConfig):
    """
    Wrapper quantization config for a fully quantized model.

    Both the layer's weights and its activations are quantized.
    """

    def __init__(self,
                 weight_quantizers: List[Callable],
                 activation_quantizers: List[Callable]):
        """
        Args:
            weight_quantizers: Quantizers applied to the layer's weights.
            activation_quantizers: Quantizers applied to the layer's activations.
        """
        super().__init__(is_weight_quantized=True,
                         is_activation_quantized=True)
        self._weight_quantizers, self._activation_quantizers = weight_quantizers, activation_quantizers

    def get_weight_quantizers(self) -> List[Callable]:
        """
        Returns: The weight quantizers this config was created with.
        """
        quantizers = self._weight_quantizers
        return quantizers

    def get_activation_quantizers(self) -> List[Callable]:
        """
        Returns: The activation quantizers this config was created with.
        """
        quantizers = self._activation_quantizers
        return quantizers
@@ -1,52 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import List, Callable
16
-
17
- from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
18
- WrapperQuantizeConfig
19
-
20
-
21
class WeightsQuantizeConfig(WrapperQuantizeConfig):
    """
    QuantizationConfig for Fully Quantized model to quantize layer's weights only.
    """

    def __init__(self,
                 weight_quantizers: List[Callable]):
        """

        Args:
            weight_quantizers: List of quantizers to quantize the layer's weights.
        """
        super().__init__(is_weight_quantized=True,
                         is_activation_quantized=False)

        self._weight_quantizers = weight_quantizers

    def get_weight_quantizers(self) -> List[Callable]:
        """

        Returns: List of quantizers to quantize the layer's weights.

        """
        return self._weight_quantizers

    def get_activation_quantizers(self) -> List[Callable]:
        """

        Returns: An empty list as this QC does not quantize the layer's activations.

        """
        # The other config classes (ActivationQuantizeConfig,
        # NoQuantizationQuantizeConfig, WeightsActivationQuantizeConfig) all
        # implement the plural `get_activation_quantizers`. The previous singular
        # spelling was never found by callers using that name.
        return []

    def get_activation_quantizer(self) -> List[Callable]:
        """
        Backward-compatible alias for `get_activation_quantizers` (this method's
        previous, misspelled name).

        Returns: An empty list as this QC does not quantize the layer's activations.
        """
        return self.get_activation_quantizers()