mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py
@@ -60,7 +60,8 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
         the thresholds per channel and the multiplier num bits.
     """
     if n_bits > MULTIPLIER_N_BITS:
-        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits, but must be less or equal to {MULTIPLIER_N_BITS}')
+        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits, but must be less or equal to '
+                        f'{MULTIPLIER_N_BITS}')  # pragma: no cover
     # TODO: need to set this externally
     if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
         n_clusters = len(np.unique(tensor_data.flatten()))
@@ -115,7 +116,8 @@ def lut_kmeans_histogram(bins: np.ndarray,
     """
 
     if n_bits >= MULTIPLIER_N_BITS:
-        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits. It must be less then {MULTIPLIER_N_BITS}')
+        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits. It must be less then '
+                        f'{MULTIPLIER_N_BITS}')  # pragma: no cover
 
     bins_with_values = np.abs(bins)[1:][counts > 0]
     if len(np.unique(bins_with_values.flatten())) < 2 ** n_bits:
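The two guards above protect the KMeans-based LUT construction that follows them: before fitting, the cluster count is capped by the number of distinct tensor values. A minimal standalone sketch of that cap (hypothetical code, assuming sklearn's KMeans as the clustering backend, as the n_clusters fallback in the hunk suggests):

    import numpy as np
    from sklearn.cluster import KMeans

    # An 8-bit LUT nominally asks for 2 ** 8 clusters, but KMeans cannot form more
    # clusters than there are distinct values, so the count falls back to that number.
    tensor_data = np.random.choice([-0.5, -0.1, 0.0, 0.2, 0.7], size=64)
    n_bits = 8
    n_clusters = min(2 ** n_bits, len(np.unique(tensor_data)))  # 5 here, not 256
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(tensor_data.reshape(-1, 1))
    lut_values = kmeans.cluster_centers_.flatten()  # cluster centers form the look-up table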
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py
@@ -49,25 +49,22 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf
                                                   bins_counts)
     min_value, max_value = out_stats_container.get_min_max_values()
 
-    if nodes_prior_info is not None:
-        if nodes_prior_info.is_output_bounded():
-            signed = min_value < 0
-        else:
-            signed = np.any(bins_values[:-1][bins_counts > 0] < 0)
-
-        if nodes_prior_info.is_output_bounded():
-            if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
-                activation_quant_cfg.activation_quantization_params_fn = \
-                    quantization_params_generation.power_of_two_no_clipping_selection_min_max
-            elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
-                activation_quant_cfg.activation_quantization_params_fn = \
-                    quantization_params_generation.symmetric_no_clipping_selection_min_max
-            elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
-                activation_quant_cfg.activation_quantization_params_fn = \
-                    quantization_params_generation.uniform_no_clipping_selection_min_max
+    if nodes_prior_info.is_output_bounded():
+        signed = min_value < 0
     else:
         signed = np.any(bins_values[:-1][bins_counts > 0] < 0)
 
+    if nodes_prior_info.is_output_bounded():
+        if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
+            activation_quant_cfg.activation_quantization_params_fn = \
+                quantization_params_generation.power_of_two_no_clipping_selection_min_max
+        elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
+            activation_quant_cfg.activation_quantization_params_fn = \
+                quantization_params_generation.symmetric_no_clipping_selection_min_max
+        elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
+            activation_quant_cfg.activation_quantization_params_fn = \
+                quantization_params_generation.uniform_no_clipping_selection_min_max
+
     activation_params = activation_quant_cfg.activation_quantization_params_fn(bins_values,
                                                                                bins_counts,
                                                                                activation_quant_cfg.l_p_value,
@@ -78,4 +75,4 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf
                                                                                quant_error_method=activation_quant_cfg.activation_error_method)
     activation_params.update({SIGNED: signed})
 
-    return activation_params
+    return activation_params
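The restructuring above drops the `nodes_prior_info is not None` wrapper and decides signedness in a single place. A condensed sketch of the resulting rule (hypothetical standalone form of the logic in the hunk):

    import numpy as np

    def is_signed(output_bounded: bool, min_value: float,
                  bins_values: np.ndarray, bins_counts: np.ndarray) -> bool:
        # Bounded outputs (e.g., sigmoid or relu6) trust the collected minimum;
        # otherwise the sign is read off the populated histogram bins.
        if output_bounded:
            return bool(min_value < 0)
        return bool(np.any(bins_values[:-1][bins_counts > 0] < 0))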
model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py
@@ -18,7 +18,7 @@ from typing import Tuple, List
 import numpy as np
 
 from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, EPS
-
+from model_compression_toolkit.core import common
 
 def max_power_of_two(x: np.ndarray,
                      min_threshold: float = MIN_THRESHOLD) -> np.ndarray:
@@ -235,7 +235,14 @@ def get_tensor_max(tensor_data: np.ndarray,
     Returns: maximal value (or values).
 
     """
-    expansion_factor = 1.0 if is_uniform_quantization else np.power(2.0, n_bits - 1) / (np.power(2.0, n_bits - 1) - 1)
+    if n_bits < 1:
+        common.Logger.error("n_bits must be positive")
+    if is_uniform_quantization:
+        expansion_factor = 1.0
+    elif n_bits == 1:
+        expansion_factor = 0.0
+    else:
+        expansion_factor = np.power(2.0, n_bits - 1) / (np.power(2.0, n_bits - 1) - 1)
     if per_channel:
         output_shape = get_output_shape(tensor_data.shape, channel_axis)
         reshaped_tensor_data = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
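The new branches make the symmetric expansion factor explicit and guard the n_bits == 1 case, where the old one-liner divided by zero. The rationale, as a sketch derived from the formula above (not MCT code): a signed symmetric n-bit grid tops out at threshold * (2 ** (n - 1) - 1) / 2 ** (n - 1), so the observed maximum is inflated by the reciprocal factor before a threshold is selected, letting the top quantization level land on the real maximum.

    import numpy as np

    def expansion_factor(n_bits: int, is_uniform: bool) -> float:
        if is_uniform:
            return 1.0   # uniform (min/max range) quantization needs no correction
        if n_bits == 1:
            return 0.0   # degenerate single-level grid: avoid division by zero
        half = 2.0 ** (n_bits - 1)
        return half / (half - 1.0)

    assert expansion_factor(8, False) == 128 / 127  # the familiar int8 ratio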
model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 
+from model_compression_toolkit.core.common.logger import Logger
 from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX, THRESHOLD
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
     quantize_tensor
@@ -51,9 +52,9 @@ def power_of_two_quantizer(tensor_data: np.ndarray,
     """
     threshold = quantization_params.get(THRESHOLD)
     if threshold is None:
-        raise Exception(f"{THRESHOLD} parameter must be defined in 'quantization_params'")
+        Logger.error(f"{THRESHOLD} parameter must be defined in 'quantization_params'")  # pragma: no cover
     if not threshold_is_power_of_two(threshold, per_channel):
-        raise Exception(f"Expects {THRESHOLD} parameter to be a power of two, but got {threshold}")
+        Logger.error(f"Expects {THRESHOLD} parameter to be a power of two, but got {threshold}")  # pragma: no cover
 
     return quantize_tensor(tensor_data,
                            threshold,
@@ -84,7 +85,7 @@ def symmetric_quantizer(tensor_data: np.ndarray,
     """
     threshold = quantization_params.get(THRESHOLD)
     if threshold is None:
-        raise Exception(f"{THRESHOLD} parameter must be defined in 'quantization_params'")
+        Logger.error(f"{THRESHOLD} parameter must be defined in 'quantization_params'")  # pragma: no cover
 
     return quantize_tensor(tensor_data,
                            threshold,
@@ -115,6 +116,6 @@ def uniform_quantizer(tensor_data: np.ndarray,
     range_min = quantization_params.get(RANGE_MIN)
     range_max = quantization_params.get(RANGE_MAX)
     if range_min is None or range_max is None:
-        raise Exception("'quantization range' parameters must be defined in 'quantization_params'")
+        Logger.error("'quantization range' parameters must be defined in 'quantization_params'")  # pragma: no cover
 
     return uniform_quantize_tensor(tensor_data, range_min, range_max, n_bits)
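For orientation, a self-contained sketch of the quantization these functions delegate to `quantize_tensor`, for the power-of-two case (assumed semantics only; the real implementation lives in quantizers_helpers.py):

    import numpy as np

    def pot_quantize(x: np.ndarray, threshold: float, n_bits: int = 8) -> np.ndarray:
        # Symmetric signed quantization with a power-of-two threshold: the scale is
        # itself a power of two, so dequantization maps to a bit shift in hardware.
        assert float(np.log2(threshold)).is_integer(), "threshold must be a power of two"
        scale = threshold / 2 ** (n_bits - 1)
        q = np.clip(np.round(x / scale), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1)
        return q * scale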
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
@@ -108,7 +108,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
 
     activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
     if activation_quantization_fn is None:
-        Logger.critical('Unknown quantization method for activations')
+        Logger.critical('Unknown quantization method for activations')  # pragma: no cover
 
     activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
 
@@ -142,14 +142,14 @@ def create_node_qc_candidate(qc: QuantizationConfig,
     weights_quantization_fn = get_weights_quantization_fn(op_cfg.weights_quantization_method)
 
     if weights_quantization_fn is None:
-        Logger.critical('Unknown quantization method for weights')
+        Logger.critical('Unknown quantization method for weights')  # pragma: no cover
 
     weights_quantization_params_fn = get_weights_quantization_params_fn(op_cfg.weights_quantization_method)
 
     # get attributes for activation quantization
     activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
     if activation_quantization_fn is None:
-        Logger.critical('Unknown quantization method for activations')
+        Logger.critical('Unknown quantization method for activations')  # pragma: no cover
 
     activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
 
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
@@ -77,6 +77,13 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
         num_nodes_before_substitution = len(graph.nodes)
         num_edges_before_substitution = len(graph.edges)
 
+        # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
+        # we should skip the substitution.
+        if source_node.reuse or source_node.reuse_group is not None:
+            for qc in source_node.candidates_quantization_cfg:
+                qc.weights_quantization_cfg.weights_second_moment_correction = False
+            return graph
+
         # We apply only on nodes with folded BatchNormalization.
         if source_node.prior_info.std_output is None or source_node.prior_info.mean_output is None:
             for qc in source_node.candidates_quantization_cfg:
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
@@ -24,6 +24,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatch
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.target_platform import QuantizationMethod
 from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.core.common.logger import Logger
 
 
 class BatchNormalizationRefusing(common.BaseSubstitution):
@@ -95,15 +96,22 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
 
         source_node = edge_nodes[0]
 
+        # We apply only on nodes with reconstructed BatchNormalization.
+        if not source_node.final_weights_quantization_cfg.weights_second_moment_correction:
+            return graph
+
         # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
         # we should skip the substitution.
         if source_node.reuse or source_node.reuse_group is not None:
-            return graph
+            Logger.exception("If the linear operator is part of a reused group we should skip the the BN folding "
+                             "substitution and SMC feature")  # pragma: no cover
 
         bn_node = edge_nodes[1]
 
         if len(graph.get_next_nodes(source_node)) > 1 or len(graph.get_prev_nodes(bn_node)) > 1:
-            return graph
+            Logger.exception(
+                "If the linear operator has multiple outputs or the bn layer has multiple inputs we should "
+                "skip the the BN folding substitution and SMC feature")  # pragma: no cover
 
         kernel = source_node.get_weights_by_keys(self.kernel_str)
         bias = source_node.get_weights_by_keys(self.bias_str)
@@ -113,9 +121,6 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
         moving_variance = bn_node.get_weights_by_keys(self.moving_variance_str)
         eps = bn_node.framework_attr[self.epsilon_str]
 
-        if bias is None:
-            bias = 0.0
-
         weights_scale = gamma / np.sqrt(moving_variance + eps)
         bias = beta + (bias - moving_mean) * weights_scale
 
@@ -177,7 +182,7 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
             corr_dict[THRESHOLD] = corr_threshold
             conv_bn.final_weights_quantization_cfg.set_weights_quantization_param(corr_dict)
 
-        # In case of SYMMETRIC weight quantization method, we update the range_min, range_max by weights_scale
+        # In case of UNIFORM weight quantization method, we update the range_min, range_max by weights_scale
         elif conv_bn.final_weights_quantization_cfg.weights_quantization_method == QuantizationMethod.UNIFORM:
             corr_dict = copy.deepcopy(conv_bn.final_weights_quantization_cfg.weights_quantization_params)
             original_range_min = conv_bn.final_weights_quantization_cfg.weights_quantization_params[RANGE_MIN]
@@ -189,5 +194,5 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
             conv_bn.final_weights_quantization_cfg.set_weights_quantization_param(corr_dict)
 
         else:
-            raise Exception("Second moment statistics correction feature disabled for models with weights "
-                            "quantization method of Power of 2")
+            Logger.exception("Second moment statistics correction feature disabled for models with weights "
+                             "quantization method of Power of 2")  # pragma: no cover
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
@@ -16,6 +16,7 @@ import copy
 import numpy as np
 from typing import List, Tuple, Any, Callable
 
+from model_compression_toolkit.core.common.logger import Logger
 from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
 from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
@@ -25,7 +26,8 @@ from model_compression_toolkit.core.common.quantization.set_node_quantization_co
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
     import get_activations_qparams
-from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import _mse_error_histogram
+from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
+    _mse_error_histogram
 from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
 
 """
@@ -73,12 +75,12 @@ def op2d_bias_correction(op2d_node: BaseNode,
 
         # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters
        if output_channel_index == input_channel_index:
-            axis_not_output_channel.remove(3) # 3 is the depth multiplier index
+            axis_not_output_channel.remove(3)  # 3 is the depth multiplier index
 
         bias_correction = shift_to_correct * np.sum(kernel, axis=tuple(axis_not_output_channel))
         op2d_node.set_weights_by_keys(bias_str, bias - bias_correction.flatten())
     else:
-        raise NotImplementedError
+        raise NotImplementedError  # pragma: no cover
 
 
 def insert_node_between_two_nodes(graph: Graph,
@@ -123,7 +125,7 @@ def insert_node_after_node(graph: Graph,
 
     last_nodes = graph.get_next_nodes(first_node)
     if len(last_nodes) != 1:
-        raise Exception('Can only insert if there is only one input')
+        Logger.error('Can only insert if there is only one input')  # pragma: no cover
     last_node = last_nodes[0]
     insert_node_between_two_nodes(graph, node_to_insert, first_node, last_node)
 
@@ -145,7 +147,7 @@ def insert_node_before_node(graph: Graph,
     """
     first_nodes = graph.get_prev_nodes(last_node)
     if len(first_nodes) != 1:
-        raise Exception('Can only insert if there is only one input')
+        Logger.error('Can only insert if there is only one input')  # pragma: no cover
     first_node = first_nodes[0]
     insert_node_between_two_nodes(graph, node_to_insert, first_node, last_node)
 
@@ -222,8 +224,8 @@ def shift_negative_function(graph: Graph,
     min_to_correct, max_value2compare = graph.get_out_stats_collector(non_linear_node).get_min_max_values()
 
     if not non_linear_node.is_all_activation_candidates_equal():
-        raise Exception("Shift negative correction is not supported for more than one activation quantization "
-                        "configuration candidate")
+        Logger.error("Shift negative correction is not supported for more than one activation quantization "
+                     "configuration candidate")  # pragma: no cover
 
     # all candidates have same activation config, so taking the first candidate for calculations
     non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg
@@ -241,7 +243,8 @@ def shift_negative_function(graph: Graph,
     # taking the minimal quantized point that is still positive.
     num_q_points = 2 ** non_linear_node_cfg_candidate.activation_n_bits
     lsb = activation_threshold / num_q_points
-    q_points = np.linspace(0, activation_threshold - lsb, num_q_points).astype('float32') # Change to type float32 to support tensorflow dtypes
+    q_points = np.linspace(0, activation_threshold - lsb, num_q_points).astype(
+        'float32')  # Change to type float32 to support tensorflow dtypes
 
     delta = q_points + min_to_correct
     delta[delta < 0] = np.inf
@@ -253,14 +256,16 @@ def shift_negative_function(graph: Graph,
                                               hist_bins, hist_count)
 
     min_mse, _th, _shift = np.inf, None, None
-    for _activation_threshold in [activation_threshold, 2*activation_threshold]:
+    for _activation_threshold in [activation_threshold, 2 * activation_threshold]:
         qparams = {THRESHOLD: _activation_threshold, SIGNED: False}
         _lsb = _activation_threshold / num_q_points
-        _q_points = np.linspace(0, _activation_threshold - _lsb, num_q_points).astype('float32') # Change to type float32 to support tensorflow dtypes
+        _q_points = np.linspace(0, _activation_threshold - _lsb, num_q_points).astype(
+            'float32')  # Change to type float32 to support tensorflow dtypes
         for _shift_value in _q_points:
             _hist_bins = hist_bins.astype(np.float32) + _shift_value
-            q_bins = non_linear_node_cfg_candidate.activation_quantization_fn(non_linear_node_cfg_candidate.activation_n_bits,
-                                                                              qparams)(_hist_bins)
+            q_bins = non_linear_node_cfg_candidate.activation_quantization_fn(
+                non_linear_node_cfg_candidate.activation_n_bits,
+                qparams)(_hist_bins)
             mse = _mse_error_histogram(q_bins, None, _hist_bins, hist_count)
             if mse < min_mse:
                 min_mse = mse
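The final hunk reformats a brute-force search: for each candidate threshold, every positive quantization point is tried as a shift of the activation histogram, and the (threshold, shift) pair with the lowest histogram MSE wins. A condensed sketch of that loop (quantize_fn and the MSE expression below are stand-ins for MCT's activation_quantization_fn and _mse_error_histogram, whose exact signatures are not shown here):

    import numpy as np

    def search_shift(hist_bins, hist_count, activation_threshold, num_q_points, quantize_fn):
        best_mse, best_th, best_shift = np.inf, None, None
        for th in (activation_threshold, 2 * activation_threshold):
            lsb = th / num_q_points
            q_points = np.linspace(0, th - lsb, num_q_points).astype('float32')
            for shift in q_points:
                shifted_bins = hist_bins.astype(np.float32) + shift
                q_bins = quantize_fn(shifted_bins, th)  # unsigned quantization at th
                mse = np.sum(hist_count * (q_bins[:-1] - shifted_bins[:-1]) ** 2) / np.sum(hist_count)
                if mse < best_mse:
                    best_mse, best_th, best_shift = mse, th, shift
        return best_th, best_shift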
model_compression_toolkit/core/common/substitutions/weights_activation_split.py
@@ -61,7 +61,7 @@ class BaseWeightsActivationSplit(BaseSubstitution):
             # Node is not composite, therefore, can't be split
             Logger.critical(f"The graph contains a node {node.name} with non composite candidates."
                             f"In order to run mixed-precision search with BOPS target KPI, "
-                            f"all model layers should be composite.")
+                            f"all model layers should be composite.")  # pragma: no cover
 
         weights_node = VirtualSplitWeightsNode(node)
         activation_node = VirtualSplitActivationNode(node, self.activation_layer_type, self.fw_attr)
model_compression_toolkit/core/common/target_platform/current_tp_model.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
+from model_compression_toolkit.core.common.logger import Logger
+
 def get_current_tp_model():
     """
 
@@ -38,7 +40,7 @@ class CurrentTPModel:
 
         """
         if self.tp_model is None:
-            raise Exception('Target platform model is not initialized.')
+            Logger.error('Target platform model is not initialized.')  # pragma: no cover
         return self.tp_model
 
     def reset(self):
model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py
@@ -16,6 +16,8 @@
 import operator
 from typing import Any, Callable, Dict
 
+from model_compression_toolkit.core.common.logger import Logger
+
 
 class Filter:
     """
@@ -31,7 +33,7 @@ class Filter:
        Returns:
            Whether the passed configuration matches the filter or not.
        """
-        raise Exception('Filter did not implement match')
+        raise NotImplemented('Filter did not implement match')  # pragma: no cover
 
 
 class AttributeFilter(Filter):
@@ -85,7 +87,7 @@ class AttributeFilter(Filter):
        """
 
        if not isinstance(other, AttributeFilter):
-            raise Exception("Not an attribute filter. Can not run an OR operation.")
+            Logger.error("Not an attribute filter. Can not run an OR operation.")  # pragma: no cover
        return OrAttributeFilter(self, other)
 
    def __and__(self, other: Any):
@@ -99,7 +101,7 @@ class AttributeFilter(Filter):
            AndAttributeFilter that filters with AND between the current AttributeFilter and the passed AttributeFilter.
        """
        if not isinstance(other, AttributeFilter):
-            raise Exception("Not an attribute filter. Can not run an AND operation.")
+            Logger.error("Not an attribute filter. Can not run an AND operation.")  # pragma: no cover
        return AndAttributeFilter(self, other)
 
    def match(self,
@@ -123,7 +125,7 @@ class AttributeFilter(Filter):
        Returns: A string representation for the filter.
 
        """
-        raise Exception("Filter must implement op_as_str ")
+        raise NotImplemented("Filter must implement op_as_str ")  # pragma: no cover
 
    def __repr__(self):
        return f'{self.attr} {self.op_as_str()} {self.value}'
@@ -267,3 +269,14 @@ class Eq(AttributeFilter):
        super().__init__(attr=attr, value=value, op=operator.eq)
 
    def op_as_str(self): return "="
+
+
+class Contains(AttributeFilter):
+    """
+    Filter configurations such that it matches configurations that have an attribute with a value that contains the value that Contains holds.
+    """
+
+    def __init__(self, attr: str, value: Any):
+        super().__init__(attr=attr, value=value, op=operator.contains)
+
+    def op_as_str(self): return " in "
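A hypothetical usage example of the new Contains filter: AttributeFilter applies op(attribute_value, filter_value), and operator.contains(a, b) evaluates `b in a`, so the filter matches when the configuration's attribute value contains the filter's value:

    # Match configurations whose 'activation' value contains 'relu' and whose
    # 'use_bias' equals True ('&' builds an AndAttributeFilter via __and__ above).
    f = Contains(attr='activation', value='relu') & Eq(attr='use_bias', value=True)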
model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py
@@ -131,9 +131,7 @@ class OperationsToLayers:
             for layer in ops2layers.layers:
                 qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name)
                 if layer in existing_layers:
-                    raise Exception(f'Found layer {layer.__name__} in more than one '
-                                    f'OperatorsSet')
+                    Logger.error(f'Found layer {layer.__name__} in more than one '
+                                 f'OperatorsSet')  # pragma: no cover
                 else:
                     existing_layers.update({layer: qco_by_opset_name})
-
-
model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py
@@ -131,7 +131,8 @@ class TargetPlatformCapabilities(ImmutableClass):
         if isinstance(tpc_component, OperationsSetToLayers):
             self.op_sets_to_layers += tpc_component
         else:
-            raise Exception(f'Trying to append an unfamiliar TargetPlatformCapabilitiesComponent of type: {type(tpc_component)}')
+            Logger.error(f'Trying to append an unfamiliar TargetPlatformCapabilitiesComponent of type: '
+                         f'{type(tpc_component)}')  # pragma: no cover
 
     def __enter__(self):
         """
@@ -175,7 +176,7 @@ class TargetPlatformCapabilities(ImmutableClass):
             QuantizationConfigOptions of the node.
         """
         if node is None:
-            raise Exception(f'Can not retrieve QC options for None node')
+            Logger.error(f'Can not retrieve QC options for None node')  # pragma: no cover
         for fl, qco in self.filterlayer2qco.items():
             if fl.match(node):
                 return qco
@@ -205,7 +206,6 @@ class TargetPlatformCapabilities(ImmutableClass):
             layer2qco.update({l: qco})
         return layer2qco, filterlayer2qco
 
-
     def remove_fusing_names_from_not_used_list(self):
         """
         Remove OperatorSets names from the list of the unused sets (so a warning
@@ -235,5 +235,3 @@ class TargetPlatformCapabilities(ImmutableClass):
         """
         for op in self.__tp_model_opsets_not_used:
             Logger.warning(f'{op} is defined in TargetPlatformModel, but is not used in TargetPlatformCapabilities.')
-
-
model_compression_toolkit/core/keras/back2framework/instance_builder.py
@@ -20,38 +20,25 @@ from typing import List, Dict, Callable
 from networkx.algorithms.dag import topological_sort
 
 import tensorflow as tf
-from tensorflow.keras.layers import Layer
+from tensorflow.keras.layers import Layer, InputLayer
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.keras.constants import LAYER_NAME
 
 
-def identity_wrapper(node: BaseNode, layer: Layer):
-    """
-    A function which takes a computational graph node and a keras layer and return an identity wrapping which return the layer itself
-    Args:
-        node: A node of mct graph.
-        layer: A keras layer
-
-    Returns: keras layer
-
-    """
-    return layer
-
-
 class OperationHandler:
     """
     Class to handle conversions from graph nodes to Keras operators and retrieving them.
     """
 
-    def __init__(self, graph: Graph, wrapper: Callable = identity_wrapper):
+    def __init__(self, graph: Graph):
         # hold nodes after sorting them
         self.node_sort = list(topological_sort(graph))
 
         self.layer_to_node_dict = {}
 
         # hold dictionary from node to its equivalent Keras layer
-        self.node_to_fw_op_dict = instance_builder(self.node_sort, wrapper)
+        self.node_to_fw_op_dict = instance_builder(self.node_sort)
 
     def get_node_op_function(self, n: BaseNode) -> Layer:
         """
@@ -86,10 +73,15 @@ def node_builder(n: common.BaseNode) -> Layer:
     Returns:
         Keras layer that was built from the node.
     """
-
     framework_attr = copy.copy(n.framework_attr)
+    if n.layer_class is InputLayer:
+        # replace input node with identity, so can wrap it with QuantizationWrapper
+        _layer_class = Layer  # Identity
+        framework_attr = {}
+    else:
+        _layer_class = n.layer_class
     framework_attr[LAYER_NAME] = n.name  # Overwrite framework name to identical graph node name
-    node_instance = n.layer_class.from_config(framework_attr)  # Build layer from node's configuration.
+    node_instance = _layer_class.from_config(framework_attr)  # Build layer from node's configuration.
     with tf.name_scope(n.name):
         # Add layer name to default weight name to avoid name duplications
         node_instance.build(n.input_shape)
@@ -98,13 +90,12 @@ def node_builder(n: common.BaseNode) -> Layer:
     return node_instance
 
 
-def instance_builder(toposort: List[BaseNode], wrapper: Callable) -> Dict[BaseNode, Layer]:
+def instance_builder(toposort: List[BaseNode]) -> Dict[BaseNode, Layer]:
     """
     Build a dictionary of nodes to their corresponding Keras
     layers, given a list of nodes.
 
     Args:
-        wrapper: A function wrapper keras Layers.
         toposort: List of nodes sorted topological to build their layers.
 
     Returns:
@@ -114,7 +105,7 @@ def instance_builder(toposort: List[BaseNode], wrapper: Callable) -> Dict[BaseNo
     nodes_dict = dict()
     for n in toposort:
         if not n.reuse:  # Hold a single node in dictionary for all reused nodes from the same layer.
-            keras_node = wrapper(n, node_builder(n))
+            keras_node = node_builder(n)
             nodes_dict.update({n: keras_node})
 
     return nodes_dict
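The InputLayer substitution above leans on the fact that a bare keras Layer is an identity (as the "# Identity" comment in the hunk notes): in TF2 the base class's call() returns its inputs unchanged, so it can stand in for the input node and still be wrapped by a quantization wrapper. A quick check of that assumption:

    import tensorflow as tf
    from tensorflow.keras.layers import Layer

    identity = Layer(name='input_identity')  # base Layer: call() returns inputs as-is
    x = tf.ones((1, 4))
    assert bool(tf.reduce_all(identity(x) == x))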