mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.38.4)
2
+ Generator: bdist_wheel (0.40.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -14,8 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
17
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GumbelConfig, \
18
- GradientPTQConfigV2
19
17
  from model_compression_toolkit.core.common.quantization import quantization_config
20
18
  from model_compression_toolkit.core.common.mixed_precision import mixed_precision_quantization_config
21
19
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
@@ -26,6 +24,7 @@ from model_compression_toolkit.core.tpc_models.get_target_platform_capabilities
26
24
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
27
25
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
28
26
  MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2
27
+ from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
29
28
  from model_compression_toolkit.core.common.logger import set_log_folder
30
29
  from model_compression_toolkit.core.common.data_loader import FolderImageLoader
31
30
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
@@ -35,21 +34,21 @@ from model_compression_toolkit.core.common import network_editors as network_edi
35
34
  from model_compression_toolkit.core.keras.quantization_facade import keras_post_training_quantization, \
36
35
  keras_post_training_quantization_mixed_precision
37
36
  from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization_experimental
38
- from model_compression_toolkit.gptq.keras.quantization_facade import \
39
- keras_gradient_post_training_quantization_experimental
40
- from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
41
- from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, \
42
- keras_quantization_aware_training_finalize
43
- from model_compression_toolkit.core.pytorch.quantization_facade import pytorch_post_training_quantization, \
44
- pytorch_post_training_quantization_mixed_precision
37
+ from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, keras_quantization_aware_training_finalize
38
+ from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init, pytorch_quantization_aware_training_finalize
39
+ from model_compression_toolkit.core.pytorch.quantization_facade import pytorch_post_training_quantization, pytorch_post_training_quantization_mixed_precision
45
40
  from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
46
- from model_compression_toolkit.gptq.pytorch.quantization_facade import \
47
- pytorch_gradient_post_training_quantization_experimental
48
- from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
49
41
 
50
42
  from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data, keras_kpi_data_experimental
51
43
  from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental
52
44
 
53
- from model_compression_toolkit.qunatizers_infrastructure.keras.load_model import keras_load_quantized_model
45
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
54
46
 
55
- __version__ = "1.7.1"
47
+
48
+ from model_compression_toolkit import exporter
49
+
50
+ from model_compression_toolkit import gptq
51
+ from model_compression_toolkit.gptq import GradientPTQConfig
52
+
53
+
54
+ __version__ = "1.8.0"
@@ -51,4 +51,4 @@ class BaseModelBuilder(ABC):
51
51
  Returns: A framework's model built from its graph.
52
52
 
53
53
  """
54
- raise NotImplemented(f'{self.__class__.__name__} have to implement build_model method.')
54
+ raise NotImplemented(f'{self.__class__.__name__} have to implement build_model method.') # pragma: no cover
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import numpy as np
17
+ from model_compression_toolkit.core.common.logger import Logger
17
18
 
18
19
 
19
20
  class BaseCollector(object):
@@ -33,7 +34,8 @@ class BaseCollector(object):
33
34
 
34
35
  """
35
36
 
36
- raise Exception(f'{self.__class__.__name__} needs to implement scale operation for its state.')
37
+ raise NotImplemented(
38
+ f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover
37
39
 
38
40
  def shift(self, shift_value: np.ndarray):
39
41
  """
@@ -43,7 +45,8 @@ class BaseCollector(object):
43
45
 
44
46
  """
45
47
 
46
- raise Exception(f'{self.__class__.__name__} needs to implement shift operation for its state.')
48
+ raise NotImplemented(
49
+ f'{self.__class__.__name__} needs to implement shift operation for its state.') # pragma: no cover
47
50
 
48
51
  def update_legal_status(self, is_illegal: bool):
49
52
  """
@@ -63,5 +66,5 @@ class BaseCollector(object):
63
66
  """
64
67
 
65
68
  if not self.is_legal:
66
- raise Exception(f'{self.__class__.__name__} was manipulated per-channel,'
67
- 'but collected per-tensor. Data is invalid.')
69
+ Logger.exception(f'{self.__class__.__name__} was manipulated per-channel,'
70
+ 'but collected per-tensor. Data is invalid.') # pragma: no cover
@@ -37,7 +37,7 @@ class BaseStatsCollector(object):
37
37
  Returns whether this tensor requires statistics collection or not.
38
38
  Should be implemented in extending classes.
39
39
  """
40
- raise Exception(f'require_collection is not implemented in {self.__class__.__name__}')
40
+ raise NotImplemented(f'require_collection is not implemented in {self.__class__.__name__}') # pragma: no cover
41
41
 
42
42
  def update_statistics(self,
43
43
  x: Any):
@@ -47,7 +47,7 @@ class BaseStatsCollector(object):
47
47
  Args:
48
48
  x: Tensor.
49
49
  """
50
- raise Exception(f'update_statistics is not implemented in {self.__class__.__name__}')
50
+ raise NotImplemented(f'update_statistics is not implemented in {self.__class__.__name__}') # pragma: no cover
51
51
 
52
52
 
53
53
  class StatsCollector(BaseStatsCollector):
@@ -21,10 +21,11 @@ FOUND_TF = importlib.util.find_spec(TENSORFLOW) is not None and importlib.util.f
21
21
  "tensorflow_model_optimization") is not None
22
22
  FOUND_TORCH = importlib.util.find_spec("torch") is not None
23
23
  FOUND_ONNX = importlib.util.find_spec("onnx") is not None
24
+ FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
24
25
 
25
26
  WEIGHTS_SIGNED = True
26
27
  # Minimal threshold to use for quantization ranges:
27
- MIN_THRESHOLD = (2 ** -28)
28
+ MIN_THRESHOLD = (2 ** -16)
28
29
  EPS = 1e-8
29
30
  MULTIPLIER_N_BITS = 8
30
31
 
@@ -114,12 +115,16 @@ ACTIVATION_QUANT_PARAMS_FN = 'activation_quantization_params_fn'
114
115
  WEIGHTS_QUANT_PARAMS_FN = 'weights_quantization_params_fn'
115
116
  WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
116
117
 
117
- # GPTQ Parameters
118
- GUMBEL_MAX_ITER = 10000
119
-
120
118
  # Memory graph constants
121
119
  DUMMY_NODE = 'dummy_node'
122
120
  DUMMY_TENSOR = 'dummy_tensor'
123
121
 
124
122
  # TP Model constants
125
123
  OPS_SET_LIST = 'ops_set_list'
124
+
125
+ # TF Input node base name
126
+ INPUT_BASE_NAME = 'base_input'
127
+
128
+ # Jacobian-weights constants
129
+ MIN_JACOBIANS_ITER = 10
130
+ JACOBIANS_COMP_TOLERANCE = 1e-3
@@ -44,7 +44,7 @@ class FrameworkImplementation(ABC):
44
44
  Returns: Module of the framework constants.
45
45
 
46
46
  """
47
- raise Exception(f'{self.__class__.__name__} did not supply a constants module.')
47
+ raise NotImplemented(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
48
48
 
49
49
  @abstractmethod
50
50
  def to_numpy(self, tensor: Any) -> np.ndarray:
@@ -57,7 +57,7 @@ class FrameworkImplementation(ABC):
57
57
  Numpy array converted from the input tensor.
58
58
  """
59
59
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
60
- f'framework\'s to_numpy method.') \
60
+ f'framework\'s to_numpy method.') # pragma: no cover
61
61
 
62
62
  @abstractmethod
63
63
  def to_tensor(self, tensor: np.ndarray) -> Any:
@@ -70,7 +70,7 @@ class FrameworkImplementation(ABC):
70
70
  Framework's tensor converted from the input Numpy array.
71
71
  """
72
72
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
73
- f'framework\'s to_tensor method.')
73
+ f'framework\'s to_tensor method.') # pragma: no cover
74
74
 
75
75
  @abstractmethod
76
76
  def model_reader(self,
@@ -86,7 +86,7 @@ class FrameworkImplementation(ABC):
86
86
  Graph representing the input model.
87
87
  """
88
88
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
89
- f'framework\'s model_reader method.')
89
+ f'framework\'s model_reader method.') # pragma: no cover
90
90
 
91
91
  @abstractmethod
92
92
  def model_builder(self,
@@ -111,7 +111,7 @@ class FrameworkImplementation(ABC):
111
111
  A tuple of the model that was built and an UserInformation object.
112
112
  """
113
113
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
114
- f'framework\'s model_builder method.')
114
+ f'framework\'s model_builder method.') # pragma: no cover
115
115
 
116
116
  @abstractmethod
117
117
  def run_model_inference(self,
@@ -128,7 +128,7 @@ class FrameworkImplementation(ABC):
128
128
  The frameworks model's output.
129
129
  """
130
130
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
131
- f'framework\'s run_model_inference method.')
131
+ f'framework\'s run_model_inference method.') # pragma: no cover
132
132
 
133
133
  @abstractmethod
134
134
  def shift_negative_correction(self,
@@ -147,7 +147,7 @@ class FrameworkImplementation(ABC):
147
147
  Graph after SNC.
148
148
  """
149
149
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
150
- f'framework\'s apply_shift_negative_correction method.')
150
+ f'framework\'s apply_shift_negative_correction method.') # pragma: no cover
151
151
 
152
152
  @abstractmethod
153
153
  def attach_sc_to_node(self, node: BaseNode, fw_info: FrameworkInfo) -> BaseStatsCollector:
@@ -163,7 +163,7 @@ class FrameworkImplementation(ABC):
163
163
  Statistics collector for the node.
164
164
  """
165
165
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
166
- f'framework\'s attach_sc_to_node method.')
166
+ f'framework\'s attach_sc_to_node method.') # pragma: no cover
167
167
 
168
168
  @abstractmethod
169
169
  def get_substitutions_channel_equalization(self,
@@ -180,7 +180,7 @@ class FrameworkImplementation(ABC):
180
180
  A list of the framework substitutions used after we collect statistics.
181
181
  """
182
182
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
183
- f'framework\'s get_substitutions_channel_equalization method.')
183
+ f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
184
184
 
185
185
  @abstractmethod
186
186
  def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
@@ -190,7 +190,7 @@ class FrameworkImplementation(ABC):
190
190
 
191
191
  """
192
192
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
193
- f'framework\'s get_substitutions_prepare_graph method.')
193
+ f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
194
194
 
195
195
  @abstractmethod
196
196
  def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
@@ -204,7 +204,7 @@ class FrameworkImplementation(ABC):
204
204
 
205
205
  """
206
206
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
207
- f'framework\'s get_substitutions_pre_statistics_collection method.')
207
+ f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover
208
208
 
209
209
  @abstractmethod
210
210
  def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
@@ -212,7 +212,7 @@ class FrameworkImplementation(ABC):
212
212
  Returns: linear collapsing substitution
213
213
  """
214
214
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
215
- f'framework\'s get_linear_collapsing_substitution method.')
215
+ f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover
216
216
 
217
217
  @abstractmethod
218
218
  def get_substitutions_statistics_correction(self, quant_config: QuantizationConfig) -> \
@@ -227,7 +227,7 @@ class FrameworkImplementation(ABC):
227
227
  A list of the framework substitutions used for statistics correction.
228
228
  """
229
229
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
230
- f'framework\'s get_substitutions_statistics_correction method.')
230
+ f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover
231
231
 
232
232
  @abstractmethod
233
233
  def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]:
@@ -235,7 +235,7 @@ class FrameworkImplementation(ABC):
235
235
  Returns: A list of the framework substitutions used for residual collapsing
236
236
  """
237
237
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
238
- f'framework\'s get_residual_collapsing_substitution method.')
238
+ f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover
239
239
 
240
240
  @abstractmethod
241
241
  def get_substitutions_pre_build(self) -> List[common.BaseSubstitution]:
@@ -245,7 +245,7 @@ class FrameworkImplementation(ABC):
245
245
 
246
246
  """
247
247
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
248
- f'framework\'s get_substitutions_pre_build method.')
248
+ f'framework\'s get_substitutions_pre_build method.') # pragma: no cover
249
249
 
250
250
  @abstractmethod
251
251
  def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[
@@ -260,7 +260,7 @@ class FrameworkImplementation(ABC):
260
260
  A list of the framework substitutions used after we collect statistics.
261
261
  """
262
262
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
263
- f'framework\'s get_substitutions_post_statistics_collection method.')
263
+ f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover
264
264
 
265
265
  @abstractmethod
266
266
  def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.BaseSubstitution]:
@@ -269,7 +269,8 @@ class FrameworkImplementation(ABC):
269
269
  """
270
270
 
271
271
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
272
- f'framework\'s get_substitutions_virtual_weights_activation_coupling method.')
272
+ f'framework\'s get_substitutions_virtual_weights_activation_coupling '
273
+ f'method.') # pragma: no cover
273
274
 
274
275
  @abstractmethod
275
276
  def get_substitutions_after_second_moment_correction(self, quant_config: QuantizationConfig) \
@@ -284,7 +285,8 @@ class FrameworkImplementation(ABC):
284
285
  A list of the framework substitutions used after we apply second moment statistics.
285
286
  """
286
287
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
287
- f'framework\'s get_substitutions_after_second_moment_correction method.')
288
+ f'framework\'s get_substitutions_after_second_moment_correction '
289
+ f'method.') # pragma: no cover
288
290
 
289
291
  @abstractmethod
290
292
  def get_gptq_trainer_obj(self):
@@ -292,7 +294,7 @@ class FrameworkImplementation(ABC):
292
294
  Returns: GPTQTrainer object
293
295
  """
294
296
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
295
- f'framework\'s get_gptq_trainer method.')
297
+ f'framework\'s get_gptq_trainer method.') # pragma: no cover
296
298
 
297
299
  @abstractmethod
298
300
  def get_sensitivity_evaluator(self,
@@ -317,7 +319,7 @@ class FrameworkImplementation(ABC):
317
319
  """
318
320
 
319
321
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
320
- f'framework\'s get_sensitivity_evaluator method.')
322
+ f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
321
323
 
322
324
  def get_node_prior_info(self, node: BaseNode,
323
325
  fw_info: FrameworkInfo,
@@ -335,7 +337,7 @@ class FrameworkImplementation(ABC):
335
337
  """
336
338
 
337
339
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
338
- f'framework\'s get_node_prior_info method.')
340
+ f'framework\'s get_node_prior_info method.') # pragma: no cover
339
341
 
340
342
  def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
341
343
  """
@@ -346,7 +348,7 @@ class FrameworkImplementation(ABC):
346
348
  """
347
349
 
348
350
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
349
- f'framework\'s count_node_for_mixed_precision_interest_points method.')
351
+ f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover
350
352
 
351
353
  def get_node_distance_fn(self, layer_class: type,
352
354
  framework_attrs: Dict[str, Any],
@@ -365,7 +367,7 @@ class FrameworkImplementation(ABC):
365
367
  """
366
368
 
367
369
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
368
- f'framework\'s get_node_distance_fn method.')
370
+ f'framework\'s get_node_distance_fn method.') # pragma: no cover
369
371
 
370
372
  @abstractmethod
371
373
  def get_model_layers_names(self,
@@ -381,7 +383,7 @@ class FrameworkImplementation(ABC):
381
383
  """
382
384
 
383
385
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
384
- f'framework\'s get_model_layers_names method.')
386
+ f'framework\'s get_model_layers_names method.') # pragma: no cover
385
387
 
386
388
  @abstractmethod
387
389
  def get_model_layer_by_name(self,
@@ -399,7 +401,7 @@ class FrameworkImplementation(ABC):
399
401
  """
400
402
 
401
403
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
402
- f'framework\'s get_model_layer_by_name method.')
404
+ f'framework\'s get_model_layer_by_name method.') # pragma: no cover
403
405
 
404
406
  @abstractmethod
405
407
  def model_grad(self,
@@ -433,7 +435,7 @@ class FrameworkImplementation(ABC):
433
435
  """
434
436
 
435
437
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
436
- f'framework\'s model_grad method.')
438
+ f'framework\'s model_grad method.') # pragma: no cover
437
439
 
438
440
  @abstractmethod
439
441
  def is_node_compatible_for_metric_outputs(self,
@@ -450,7 +452,7 @@ class FrameworkImplementation(ABC):
450
452
  """
451
453
 
452
454
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
453
- f'framework\'s is_node_compatible_for_metric_outputs method.')
455
+ f'framework\'s is_node_compatible_for_metric_outputs method.') # pragma: no cover
454
456
 
455
457
  @abstractmethod
456
458
  def get_node_mac_operations(self,
@@ -467,7 +469,7 @@ class FrameworkImplementation(ABC):
467
469
  """
468
470
 
469
471
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
470
- f'framework\'s get_node_mac_operations method.')
472
+ f'framework\'s get_node_mac_operations method.') # pragma: no cover
471
473
 
472
474
  @abstractmethod
473
475
  def apply_second_moment_correction(self,
@@ -488,7 +490,7 @@ class FrameworkImplementation(ABC):
488
490
  A Graph after second moment correction.
489
491
  """
490
492
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
491
- f'framework\'s apply_second_moment_correction method.')
493
+ f'framework\'s apply_second_moment_correction method.') # pragma: no cover
492
494
 
493
495
  @abstractmethod
494
496
  def sensitivity_eval_inference(self,
@@ -505,4 +507,4 @@ class FrameworkImplementation(ABC):
505
507
  The output of the model inference on the given input.
506
508
  """
507
509
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
508
- f'framework\'s sensitivity_eval_inference method.')
510
+ f'framework\'s sensitivity_eval_inference method.') # pragma: no cover
@@ -75,7 +75,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
75
75
  self.fused_nodes = []
76
76
 
77
77
  def set_fw_info(self,
78
- fw_info: FrameworkInfo):
78
+ fw_info: FrameworkInfo):
79
79
  """
80
80
  Set the graph's framework info.
81
81
  Args:
@@ -93,7 +93,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
93
93
  """
94
94
  self.tpc = tpc
95
95
 
96
-
97
96
  def get_topo_sorted_nodes(self):
98
97
  """
99
98
  Returns: a list of toposorted nodes.
@@ -216,7 +215,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
216
215
 
217
216
  sc = self.node_to_in_stats_collector.get(n)
218
217
  if sc is None:
219
- raise Exception()
218
+ Logger.error(f'Input statistics collector of node {n.name} is None') # pragma: no cover
220
219
  return sc
221
220
 
222
221
  def scale_stats_collector(self,
@@ -350,7 +349,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
350
349
  input_nodes_output_index = [0] * len(input_nodes)
351
350
 
352
351
  if len(input_nodes_output_index) != len(input_nodes):
353
- raise Exception('Graph.add_node_with_in_edges: input_nodes & input_nodes_output_index must be the same length')
352
+ Logger.error('Graph.add_node_with_in_edges: input_nodes & input_nodes_output_index must be the same '
353
+ 'length') # pragma: no cover
354
354
 
355
355
  self.add_node(new_node)
356
356
  for sink_index, (in_node, source_index) in enumerate(zip(input_nodes, input_nodes_output_index)):
@@ -420,12 +420,14 @@ class Graph(nx.MultiDiGraph, GraphSearches):
420
420
  output_nodes = [ot.node for ot in self.get_outputs()] # get output nodes from namedtuples
421
421
  if node_to_remove in output_nodes: # If node is in the graph's outputs, the outputs should be updated
422
422
  if new_graph_outputs is None:
423
- Logger.critical(f'{node_to_remove.name} is in graph outputs, but new outputs were not given.')
423
+ Logger.critical(
424
+ f'{node_to_remove.name} is in graph outputs, but new outputs were not given.') # pragma: no cover
424
425
  self.set_outputs(new_graph_outputs)
425
426
 
426
427
  if node_to_remove in self.get_inputs(): # If node is in the graph's inputs, the inputs should be updated
427
428
  if new_graph_inputs is None:
428
- Logger.critical(f'{node_to_remove.name} is in graph inputs, but new inputs were not given.')
429
+ Logger.critical(
430
+ f'{node_to_remove.name} is in graph inputs, but new inputs were not given.') # pragma: no cover
429
431
  self.set_inputs(new_graph_inputs)
430
432
 
431
433
  # Make sure there are no connected edges left to the node before removing it.
@@ -17,7 +17,6 @@
17
17
  import logging
18
18
  import os
19
19
  from datetime import datetime
20
- from os import path
21
20
  from pathlib import Path
22
21
 
23
22
  LOGGER_NAME = 'Constrained Model Optimization'
@@ -43,7 +42,7 @@ class Logger:
43
42
 
44
43
  """
45
44
 
46
- if not path.exists(log_path):
45
+ if not os.path.exists(log_path):
47
46
  Path(log_path).mkdir(parents=True, exist_ok=True)
48
47
 
49
48
  @staticmethod
@@ -93,6 +92,15 @@ class Logger:
93
92
 
94
93
  print(f'log file is in {log_name}')
95
94
 
95
+ @staticmethod
96
+ def shutdown():
97
+ """
98
+ An orderly command to shutdown by flushing and closing all logging handlers.
99
+
100
+ """
101
+ Logger.LOG_PATH = None
102
+ logging.shutdown()
103
+
96
104
  ########################################
97
105
  # Delegating methods to wrapped logger
98
106
  ########################################
@@ -41,19 +41,19 @@ class BaseMatcher(object):
41
41
  """
42
42
  Return a matcher to check the logic AND of two BaseMatchers on an object.
43
43
  """
44
- raise NotImplemented
44
+ raise NotImplemented # pragma: no cover
45
45
 
46
46
  def __or__(self, other: Any):
47
47
  """
48
48
  Return a matcher to check the logic OR of BaseMatchers on an object.
49
49
  """
50
- raise NotImplemented
50
+ raise NotImplemented # pragma: no cover
51
51
 
52
52
  def logic_not(self):
53
53
  """
54
54
  Return a matcher to check the logic NOT of the BaseMatcher on an object.
55
55
  """
56
- raise NotImplemented
56
+ raise NotImplemented # pragma: no cover
57
57
 
58
58
  def logic_and(self, other: Any):
59
59
  """
@@ -127,7 +127,8 @@ class MixedPrecisionQuantizationConfig(QuantizationConfig):
127
127
  elif hasattr(_dummy_mp_config_experimental, k):
128
128
  mp_dict.update({k: v})
129
129
  else:
130
- raise Exception(f'Attribute "{k}" mismatch: exists in MixedPrecisionQuantizationConfig but not in MixedPrecisionQuantizationConfigV2') # pragma: no cover
130
+ Logger.error(f'Attribute "{k}" mismatch: exists in MixedPrecisionQuantizationConfig but not in '
131
+ f'MixedPrecisionQuantizationConfigV2') # pragma: no cover
131
132
 
132
133
  return QuantizationConfig(**qc_dict), MixedPrecisionQuantizationConfigV2(**mp_dict)
133
134
 
@@ -75,7 +75,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
75
75
 
76
76
  # target_kpi have to be passed. If it was not passed, the facade is not supposed to get here by now.
77
77
  if target_kpi is None:
78
- Logger.critical('Target KPI have to be passed for search_methods bit-width configuration')
78
+ Logger.critical('Target KPI have to be passed for search_methods bit-width configuration') # pragma: no cover
79
79
 
80
80
  # Set graph for MP search
81
81
  graph = copy.deepcopy(graph_to_search_cfg) # Copy graph before searching
@@ -114,7 +114,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
114
114
  if search_method in search_methods: # Get a specific search function
115
115
  search_method_fn = search_methods.get(search_method)
116
116
  else:
117
- raise NotImplemented
117
+ raise NotImplemented # pragma: no cover
118
118
 
119
119
  # Search for the desired mixed-precision configuration
120
120
  result_bit_cfg = search_method_fn(search_manager,
@@ -350,8 +350,8 @@ class ConfigReconstructionHelper:
350
350
 
351
351
  if changed_virtual_nodes_idx is not None:
352
352
  if original_base_config is None:
353
- Logger.critical("Must provide a base original config in order to run config reconstruction for partial" # pragma: no cover
354
- "set of nodes.")
353
+ Logger.critical("Must provide a base original config in order to run config reconstruction for partial"
354
+ "set of nodes.") # pragma: no cover
355
355
 
356
356
  updated_virtual_nodes = \
357
357
  [(idx, self.virtual_graph.get_configurable_sorted_nodes()[idx]) for idx in changed_virtual_nodes_idx]
@@ -22,6 +22,8 @@ from model_compression_toolkit.core.common import Logger
22
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
23
23
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
24
24
 
25
+ # Limit ILP solver runtime in seconds
26
+ SOLVER_TIME_LIMIT = 60
25
27
 
26
28
  def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
27
29
  target_kpi: KPI = None) -> List[int]:
@@ -64,7 +66,10 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
64
66
  target_kpi,
65
67
  search_manager)
66
68
 
67
- lp_problem.solve() # Try to solve the problem.
69
+ # Use default PULP solver. Limit runtime in seconds
70
+ solver = PULP_CBC_CMD(timeLimit=SOLVER_TIME_LIMIT)
71
+ lp_problem.solve(solver=solver) # Try to solve the problem.
72
+
68
73
  assert lp_problem.status == LpStatusOptimal, Logger.critical(
69
74
  "No solution was found during solving the LP problem")
70
75
  Logger.info(LpStatus[lp_problem.status])
@@ -30,7 +30,8 @@ class ModelValidation:
30
30
  If the model has layers with different output channels index, it should throw an exception.
31
31
 
32
32
  """
33
- raise NotImplemented(f'Framework validation class did not implement validate_output_channel_consistency')
33
+ raise NotImplemented(
34
+ f'Framework validation class did not implement validate_output_channel_consistency') # pragma: no cover
34
35
 
35
36
  def validate(self):
36
37
  """
@@ -17,6 +17,8 @@
17
17
  from typing import Callable, Any
18
18
 
19
19
  import numpy as np
20
+
21
+ from model_compression_toolkit.core.common.logger import Logger
20
22
  from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
21
23
  get_activation_quantization_params_fn, get_weights_quantization_params_fn
22
24
 
@@ -111,7 +113,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
111
113
  self.activation_quantization_params)
112
114
 
113
115
  if fake_quant is None:
114
- raise Exception('Layer is meant to be quantized but fake_quant function is None')
116
+ Logger.error('Layer is meant to be quantized but fake_quant function is None') # pragma: no cover
115
117
  return fake_quant(tensors)
116
118
 
117
119
  @property
@@ -16,6 +16,7 @@
16
16
  from collections.abc import Callable
17
17
  from functools import partial
18
18
 
19
+ from model_compression_toolkit.core.common.logger import Logger
19
20
  from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
21
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.kmeans_params import kmeans_tensor
21
22
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
@@ -47,8 +48,9 @@ def get_activation_quantization_params_fn(activation_quantization_method: Quanti
47
48
  elif activation_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
48
49
  params_fn = lut_kmeans_histogram
49
50
  else:
50
- raise Exception(
51
- f'No params function for the configuration of quantization method {activation_quantization_method}')
51
+ Logger.error(
52
+ f'No params function for the configuration of '
53
+ f'quantization method {activation_quantization_method}') # pragma: no cover
52
54
  return params_fn
53
55
 
54
56
 
@@ -75,6 +77,7 @@ def get_weights_quantization_params_fn(weights_quantization_method: Quantization
75
77
  elif weights_quantization_method == QuantizationMethod.LUT_SYM_QUANTIZER:
76
78
  params_fn = partial(lut_kmeans_tensor, is_symmetric=True)
77
79
  else:
78
- raise Exception(
79
- f'No params function for the configuration of quantization method {weights_quantization_method}')
80
+ Logger.error(
81
+ f'No params function for the configuration of '
82
+ f'quantization method {weights_quantization_method}') # pragma: no cover
80
83
  return params_fn