mct-nightly 1.8.0.22042023.post414__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff shows the contents of publicly released package versions from one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
Files changed (238)
  1. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +1 -1
  2. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +237 -230
  3. model_compression_toolkit/__init__.py +8 -31
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/fusion/layer_fusing.py +2 -2
  13. model_compression_toolkit/core/common/graph/base_graph.py +1 -1
  14. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  15. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  17. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  18. model_compression_toolkit/core/common/memory_computation.py +1 -1
  19. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +2 -3
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  22. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  27. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  28. model_compression_toolkit/core/common/model_collector.py +2 -2
  29. model_compression_toolkit/core/common/model_validation.py +1 -1
  30. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  31. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  32. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  33. model_compression_toolkit/core/common/quantization/node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  35. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  36. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  37. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +2 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +2 -1
  47. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  49. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  50. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  52. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -2
  53. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  54. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -3
  55. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -2
  56. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  57. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +2 -2
  58. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  59. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +4 -4
  60. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  61. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  62. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  63. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  64. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  65. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  66. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  67. model_compression_toolkit/core/keras/back2framework/model_gradients.py +2 -2
  68. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  69. model_compression_toolkit/core/keras/constants.py +0 -7
  70. model_compression_toolkit/core/keras/default_framework_info.py +2 -2
  71. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  72. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  73. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  74. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  79. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  80. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  81. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  82. model_compression_toolkit/core/keras/kpi_data_facade.py +7 -7
  83. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  84. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  85. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  86. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  87. model_compression_toolkit/core/keras/reader/common.py +1 -1
  88. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  89. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  90. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  91. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  92. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +2 -2
  93. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  94. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  95. model_compression_toolkit/core/pytorch/constants.py +0 -6
  96. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  98. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  99. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  100. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  101. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  102. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  103. model_compression_toolkit/core/pytorch/kpi_data_facade.py +6 -6
  104. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  105. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -9
  106. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  107. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  108. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  109. model_compression_toolkit/core/pytorch/reader/graph_builders.py +3 -2
  110. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  111. model_compression_toolkit/core/runner.py +6 -6
  112. model_compression_toolkit/exporter/__init__.py +6 -3
  113. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  114. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  115. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  116. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  117. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  118. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  119. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  120. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  121. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  123. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  124. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -2
  125. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  126. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  127. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +3 -2
  128. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  129. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  130. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  131. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  132. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  133. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  134. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  135. model_compression_toolkit/gptq/keras/quantization_facade.py +26 -19
  136. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  137. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +1 -1
  138. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  139. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +2 -2
  140. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  141. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  142. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  143. model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -11
  144. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  145. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +1 -3
  146. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  147. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +2 -2
  148. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  149. model_compression_toolkit/gptq/runner.py +3 -2
  150. model_compression_toolkit/{exporter/model_exporter/tflite → legacy}/__init__.py +1 -1
  151. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +8 -9
  152. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +8 -9
  153. model_compression_toolkit/ptq/__init__.py +3 -0
  154. model_compression_toolkit/ptq/keras/quantization_facade.py +10 -11
  155. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -7
  156. model_compression_toolkit/qat/__init__.py +4 -0
  157. model_compression_toolkit/qat/common/__init__.py +1 -2
  158. model_compression_toolkit/qat/common/qat_config.py +5 -1
  159. model_compression_toolkit/qat/keras/quantization_facade.py +33 -27
  160. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  161. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  162. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +12 -10
  163. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +8 -8
  164. model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
  165. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  166. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +3 -2
  167. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +6 -4
  168. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  169. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -2
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +1 -1
  178. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  179. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +1 -1
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +1 -1
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +1 -1
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +1 -1
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +2 -2
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +1 -2
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +1 -1
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +1 -1
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +1 -1
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +1 -1
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +1 -1
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +1 -1
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  201. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  202. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +3 -5
  203. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  205. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  206. model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
  207. model_compression_toolkit/target_platform_capabilities/target_platform/operators.py +1 -1
  208. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  209. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +11 -2
  210. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  211. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py +32 -34
  212. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -2
  213. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -24
  214. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +1 -1
  215. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/target_platform_capabilities.py +3 -1
  216. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v1/tp_model.py +7 -1
  217. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v2/tp_model.py +7 -1
  218. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v3/tp_model.py +7 -1
  219. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v3_lut/tp_model.py +7 -2
  220. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v4/tp_model.py +7 -1
  221. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v4_lut/tp_model.py +7 -2
  222. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v5/tp_model.py +7 -1
  223. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +1 -3
  224. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +1 -1
  225. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +2 -1
  226. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  227. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +1 -1
  228. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +2 -1
  229. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  230. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +1 -1
  231. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +2 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  233. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  234. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  235. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  236. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  237. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  238. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
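
Before the hunks, a note on the pattern that dominates this listing: constants.py, logger.py, and immutable.py move out of core/common to package-level modules, the TFLite exporters move under the Keras exporter package, the old core quantization facades move to a new legacy package, and GPTQ gains framework-implementation classes. A minimal, hedged sketch of the import-path change that recurs throughout the hunks below (old paths commented out, new paths active; the exact set of public re-exports may vary by build):

# Old (1.8.0.22042023.post414), as seen in the removed lines:
# from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, THRESHOLD
# from model_compression_toolkit.core.common.logger import Logger

# New (1.8.0.22052023.post408), as seen in the added lines:
from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD
from model_compression_toolkit.logger import Logger
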
@@ -18,7 +18,7 @@ from typing import Any, Tuple, Dict
 
 import numpy as np
 
-from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFAULT_DEC_FACTOR, \
+from model_compression_toolkit.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFAULT_DEC_FACTOR, \
     SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS, SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER, SYMMETRIC_TENSOR_DEC_FREQ, \
     SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ, SYMMETRIC_TENSOR_N_INTERVALS, SYMMETRIC_TENSOR_N_ITER, \
     UNIFORM_TENSOR_PER_CHANNEL_N_ITER, UNIFORM_TENSOR_N_ITER, SYMMETRIC_HISTOGRAM_DEC_FREQ, SYMMETRIC_HISTOGRAM_N_ITER, \
@@ -16,7 +16,7 @@ from typing import Dict, Any, Tuple
 
 import numpy as np
 
-from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
@@ -15,7 +15,7 @@
 import numpy as np
 
 import model_compression_toolkit.core.common.quantization.quantization_config as qc
-from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, THRESHOLD
+from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
     get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function, _kl_error_histogram
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
@@ -15,7 +15,7 @@
 import numpy as np
 
 import model_compression_toolkit.core.common.quantization.quantization_config as qc
-from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
     qparams_uniform_selection_tensor_search, qparams_uniform_selection_histogram_search
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
+from model_compression_toolkit.logger import Logger
 
 
 def quantize_graph_weights(graph: Graph,
@@ -47,7 +48,7 @@ def quantize_graph_weights(graph: Graph,
                                                              n.final_weights_quantization_cfg,
                                                              fw_impl=fw_impl)
 
-        common.Logger.debug(
+        Logger.debug(
             f'Node name: {n.name} has the following quantization params: '
             f'{str(n.final_weights_quantization_cfg.weights_quantization_params)}')
 
@@ -15,7 +15,7 @@
 
 
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
@@ -46,7 +46,7 @@ def get_quantized_kernel_by_weights_qc(fw_info: FrameworkInfo,
     # If weights should be quantized per-channel but a kernel channels mapping is missing.
     if weights_qc.weights_per_channel_threshold and fw_info.kernel_channels_mapping is \
             None:
-        common.Logger.warning(
+        Logger.warning(
             'Weights Per Channel Quantization requires channel mapping function but framework info '
             'does not contain one')
     output_channels_axis, input_channels_axis = get_channels_axis(weights_qc,
@@ -16,7 +16,7 @@
 from sklearn.cluster import KMeans
 import numpy as np
 
-from model_compression_toolkit.core.common.constants import CLUSTER_CENTERS, MIN_THRESHOLD, SCALE_PER_CHANNEL
+from model_compression_toolkit.constants import CLUSTER_CENTERS, MIN_THRESHOLD, SCALE_PER_CHANNEL
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters
 
 
@@ -15,7 +15,7 @@
 
 import numpy as np
 
-from model_compression_toolkit.core.common.constants import CLUSTER_CENTERS, SCALE_PER_CHANNEL, \
+from model_compression_toolkit.constants import CLUSTER_CENTERS, SCALE_PER_CHANNEL, \
     MULTIPLIER_N_BITS
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters, \
     get_quantized_tensor, int_quantization_with_threshold
@@ -17,8 +17,10 @@
 from typing import Tuple, List
 import numpy as np
 
-from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, EPS
+from model_compression_toolkit.constants import MIN_THRESHOLD, EPS
 from model_compression_toolkit.core import common
+from model_compression_toolkit.logger import Logger
+
 
 def max_power_of_two(x: np.ndarray,
                      min_threshold: float = MIN_THRESHOLD) -> np.ndarray:
@@ -236,7 +238,7 @@ def get_tensor_max(tensor_data: np.ndarray,
 
     """
     if n_bits < 1:
-        common.Logger.error("n_bits must be positive")
+        Logger.error("n_bits must be positive")
     if is_uniform_quantization:
         expansion_factor = 1.0
     elif n_bits == 1:
@@ -15,8 +15,8 @@
 
 import numpy as np
 
-from model_compression_toolkit.core.common.logger import Logger
-from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX, THRESHOLD
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX, THRESHOLD
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
     quantize_tensor
 
@@ -17,7 +17,8 @@
 import copy
 from typing import List
 
-from model_compression_toolkit.core.common import Logger, BaseNode
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
@@ -72,7 +73,7 @@ def set_quantization_configs_to_node(node: BaseNode,
         tpc: TargetPlatformCapabilities to get default OpQuantizationConfig.
         mixed_precision_enable: is mixed precision enabled
     """
-    node_qc_options = tpc.get_qco_by_node(node)
+    node_qc_options = node.get_qco(tpc)
 
     # Create QC candidates for weights and activation combined
     weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)[0]
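
The last hunk above (together with two hunks in shift_negative_activation.py further down) moves the quantization-config-options lookup from the TargetPlatformCapabilities object onto the node itself, matching the +57-line change to core/common/graph/base_node.py in the file listing. A hedged sketch of the call-site migration; both forms are taken verbatim from this diff:

# Previous release:
# node_qc_options = tpc.get_qco_by_node(node)

# This release: the node resolves its own quantization config options from the TPC.
node_qc_options = node.get_qco(tpc)
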
@@ -13,11 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Any, Tuple
+from typing import Any
 
 import numpy as np
 
-from model_compression_toolkit.core.common.constants import EPS
+from model_compression_toolkit.constants import EPS
 
 #########################
 # Helpful functions
@@ -14,12 +14,12 @@
 # ==============================================================================
 import copy
 
-from model_compression_toolkit import CoreConfig
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 
 
-def apply_bias_correction_to_graph(graph: Graph,
+def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
                                    core_config: CoreConfig,
                                    fw_impl: FrameworkImplementation) -> Graph:
     """
@@ -27,7 +27,7 @@ apply_bias_correction_to_graph(graph: Graph,
     correction term in it), and apply the bias correction for each node in the graph.
 
     Args:
-        graph: Graph to apply bias correction to.
+        graph_to_apply_bias_correction: Graph to apply bias correction to.
         core_config: CoreConfig containing parameters of how the model should be quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
@@ -35,6 +35,7 @@ apply_bias_correction_to_graph(graph: Graph,
         Graph with bias correction apply to it's nodes.
     """
 
+    graph = copy.deepcopy(graph_to_apply_bias_correction)
    for n in graph.nodes:
        if n.is_weights_quantization_enabled() and core_config.quantization_config.weights_bias_correction \
                and not n.final_weights_quantization_cfg.weights_second_moment_correction:
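
Note on the two hunks above: the parameter is renamed and a deep copy is added, so bias correction now operates on a copy rather than mutating the caller's graph in place (this rationale is inferred from the change; the diff itself does not state it). A minimal sketch of the resulting pattern, with the body elided:

import copy

def apply_bias_correction_to_graph(graph_to_apply_bias_correction, core_config, fw_impl):
    # Work on a deep copy so the graph passed by the caller stays unchanged.
    graph = copy.deepcopy(graph_to_apply_bias_correction)
    # ... per-node bias correction as shown in the hunks above ...
    return graph
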
@@ -18,12 +18,13 @@ from typing import Any
 
 import numpy as np
 
-from model_compression_toolkit import CoreConfig
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common import BaseNode, Logger, Graph
+from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
+from model_compression_toolkit.logger import Logger
 
 
 def compute_bias_correction_of_graph(graph: Graph,
@@ -20,7 +20,7 @@ from typing import Callable
 import numpy as np
 
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
@@ -23,8 +23,8 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
-from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.logger import Logger
 
 
 class BatchNormalizationRefusing(common.BaseSubstitution):
@@ -17,7 +17,7 @@
 import copy
 import numpy as np
 from typing import Tuple, Callable
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher
@@ -16,9 +16,9 @@ import copy
 import numpy as np
 from typing import List, Tuple, Any, Callable
 
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
-from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
+from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
@@ -356,7 +356,7 @@ def shift_negative_function(graph: Graph,
        bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
        graph.shift_stats_collector(bypass_node, np.array(shift_value))
 
-    add_node_qco = graph.tpc.get_qco_by_node(add_node).quantization_config_list
+    add_node_qco = add_node.get_qco(graph.tpc).quantization_config_list
    for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
        candidate_qc.weights_quantization_cfg.enable_weights_quantization = False
 
@@ -495,7 +495,7 @@ def apply_shift_negative_correction(graph: Graph,
    nodes = list(graph.nodes())
    for n in nodes:
        # Skip substitution if QuantizationMethod is uniform.
-        node_qco = graph.tpc.get_qco_by_node(n)
+        node_qco = n.get_qco(graph.tpc)
        if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
                for op_qc in node_qco.quantization_config_list]):
            continue
@@ -14,7 +14,7 @@
 # ==============================================================================
 
 from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
 
 
@@ -14,7 +14,7 @@
 # ==============================================================================
 
 import itertools
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
 from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualSplitWeightsNode, \
     VirtualSplitActivationNode
@@ -31,7 +31,7 @@ from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
 from tensorboard.summary.writer.event_file_writer import EventFileWriter
 from typing import List, Any, Dict
 from networkx import topological_sort
-from model_compression_toolkit import FrameworkInfo
+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
 
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
 from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import List
 
-from model_compression_toolkit import FrameworkInfo
+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -13,16 +13,15 @@
 # limitations under the License.
 # ==============================================================================
 
-from abc import abstractmethod
-
 import tensorflow as tf
 from keras.engine.input_layer import InputLayer
 from keras.models import Model, clone_model
 from packaging import version
 
+from model_compression_toolkit.constants import INPUT_BASE_NAME
 from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.common.constants import INPUT_BASE_NAME
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder
 
 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
 if version.parse(tf.__version__) < version.parse("2.6"):
@@ -38,7 +37,6 @@ else:
 from typing import Any, Dict, List, Tuple, Callable
 from tensorflow.python.util.object_identity import Reference as TFReference
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
-from model_compression_toolkit.core.common.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -95,7 +93,8 @@ class KerasModelBuilder(BaseModelBuilder):
                  append2output=None,
                  fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                  return_float_outputs: bool = False,
-                 wrapper: Callable = None):
+                 wrapper: Callable = None,
+                 get_activation_quantizer_holder_fn: Callable=None):
        """

        Args:
@@ -104,6 +103,8 @@ class KerasModelBuilder(BaseModelBuilder):
            fw_info: Information about the specific framework of the model that is built.
            return_float_outputs: Whether the model returns float tensors or not.
            wrapper: A function wrapper keras Layers.
+            get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
+
        """

        super().__init__(graph,
@@ -114,6 +115,19 @@ class KerasModelBuilder(BaseModelBuilder):
        # Build an OperationHandler to handle conversions from graph nodes to Keras operators.
        self.oh = OperationHandler(self.graph)
        self.wrapper = wrapper
+        self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
+
+    @property
+    def use_activation_holder_during_model_building(self) -> bool:
+        """
+
+        Returns: Whether the model builder uses ActivationQuantizationHolder during
+        model building (by adding it as a layer when converting the graph to the Keras model)
+        or not. If so - the model builder expects the activation quantizers to not be wrapped
+        in KerasQuantizeWrapper that was received in its init.
+
+        """
+        return self.get_activation_quantizer_holder is not None

    def _quantize_node_activations(self,
                                   node: BaseNode,
@@ -187,9 +201,8 @@ class KerasModelBuilder(BaseModelBuilder):
                node_to_output_tensors_dict.update({n: [out_tensors_of_n]})
                node_to_output_tensors_dict_float.update({n: [out_tensors_of_n_float]})

-        # convert node_to_output_tensors_dict keys to nodes' names since oh.node_sort contains different objects
-        # than
-        # original graph nodes.
+        # convert node_to_output_tensors_dict keys to nodes' names since oh.node_sort
+        # contains different objects than original graph nodes.
        node_name_to_outtensors = self._convert_node2name(node_to_output_tensors_dict)
        node_name_to_outtensors_float = self._convert_node2name(node_to_output_tensors_dict_float)

@@ -214,9 +227,12 @@ class KerasModelBuilder(BaseModelBuilder):
        def _wrap(layer):
            _node = self.oh.layer_to_node_dict.get(layer)
            if _node is not None:
-                return self.wrapper(_node, layer)
-            elif is_layer_fake_quant(layer):
+                return self.wrapper(_node,
+                                    layer)
+
+            elif is_layer_fake_quant(layer) or isinstance(layer, ActivationQuantizationHolder):
                return layer
+
            raise Exception( # pragma: no cover
                f'Mismatch between keras model and graph cant find node named: '
                f'{get_node_name_from_layer(layer)}')
@@ -278,13 +294,9 @@ class KerasModelBuilder(BaseModelBuilder):
        """
        if len(input_tensors) == 0: # Placeholder handling
            out_tensors_of_n_float = input_nodes_to_input_tensors[n]
-            if self.wrapper is not None:
-                # if a wrapper is defined, add an identity layer for cloning. The Identity will be warpped
-                out_tensors_of_n = op_func(out_tensors_of_n_float)
-            elif n.is_activation_quantization_enabled():
-                out_tensors_of_n = self._quantize_node_activations(n, out_tensors_of_n_float)
-            else:
-                out_tensors_of_n = out_tensors_of_n_float
+            out_tensors_of_n = self._run_operation_activation_quantization(n,
+                                                                           out_tensors_of_n_float,
+                                                                           op_func)
        else:
            input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
            # Build a functional node using its args
@@ -299,11 +311,9 @@ class KerasModelBuilder(BaseModelBuilder):
            if len(input_tensors) == 1:
                input_tensors = input_tensors[0]
            out_tensors_of_n_float = op_func(input_tensors)
-            out_tensors_of_n = out_tensors_of_n_float

-            # Add a fake quant node if the node has an activation threshold and a wrapper isn't defined
-            if n.is_activation_quantization_enabled() and self.wrapper is None:
-                out_tensors_of_n = self._quantize_node_activations(n, out_tensors_of_n_float)
+            out_tensors_of_n = self._run_operation_activation_quantization(n,
+                                                                           out_tensors_of_n_float)

        # Save a mapping from the layer that created the tensor to the node (as this layer is not the
        # same instance as op_func. We do this to solve an issue that names are different between these
@@ -318,3 +328,38 @@ class KerasModelBuilder(BaseModelBuilder):
            self.oh.layer_to_node_dict[layer_from_tensor] = n

        return out_tensors_of_n, out_tensors_of_n_float
+
+    def _run_operation_activation_quantization(self,
+                                               node: BaseNode,
+                                               node_outputs: List[TFReference],
+                                               identity_layer: Layer = None):
+        """
+        Quantize node's activations
+
+        Args:
+            node: Node to quantize its activations
+            node_outputs: Output tensors of the float node.
+            identity_layer: Identity layer (should be passed only when quantizing input layers)
+
+        Returns:
+            Quantized node's outputs.
+        """
+        if self.wrapper is not None:
+            # If identity layer was passed, use it for inference
+            # (needed since wrapping an Input layer can not be wrapped)
+            if identity_layer is not None:
+                node_outputs = identity_layer(node_outputs)
+
+            # In case the activation quantizer is attached out of the wrapper we use get_activation_quantizer_holder
+            # for the activation quantization holder (if the node's activations are quantized)
+            if node.is_activation_quantization_enabled() and self.use_activation_holder_during_model_building:
+                activation_quantizer_holder = self.get_activation_quantizer_holder(node)
+                quantized_node_outputs = activation_quantizer_holder(node_outputs)
+                return quantized_node_outputs
+
+        elif node.is_activation_quantization_enabled(): # Used only when old exporter is used
+            quantized_node_outputs = self._quantize_node_activations(node,
+                                                                     node_outputs)
+            return quantized_node_outputs
+
+        return node_outputs
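
Taken together, the keras_model_builder.py hunks above add an optional activation-quantization-holder path: when a get_activation_quantizer_holder_fn callable is supplied, quantized activations are attached as standalone ActivationQuantizationHolder layers instead of being fake-quantized inside the layer wrapper. A hedged usage sketch based only on the signatures visible in this diff; the holder construction and the wrapper callable are illustrative assumptions, not the released API:

from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder

def get_activation_quantizer_holder(node):
    # Illustrative assumption: some factory builds the node's inferable activation
    # quantizer(s) and wraps them in the new holder layer; the real constructor
    # lives in the +147-line activation_quantization_holder.py, not shown here.
    return build_holder_for(node)  # hypothetical helper

builder = KerasModelBuilder(graph=quantized_graph,               # graph produced by the core pipeline (assumed)
                            wrapper=wrap_layer_with_quantizers,  # hypothetical KerasQuantizeWrapper factory
                            get_activation_quantizer_holder_fn=get_activation_quantizer_holder)
model, user_info = builder.build_model()                         # build_model() per MCT's model-builder pattern
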
@@ -36,7 +36,7 @@ else:
    from keras.layers.core import TFOpLambda, SlicingOpLambda

 from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -26,13 +26,13 @@ else:

 from typing import Any, Dict, List, Tuple
 from tensorflow.python.util.object_identity import Reference as TFReference
-from model_compression_toolkit.core.common.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
+from model_compression_toolkit.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger


 def build_input_tensors_list(node: BaseNode,
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import List

-from model_compression_toolkit import FrameworkInfo
+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
@@ -99,13 +99,6 @@ OUTPUT_BIAS = '/attention_output/bias'
 # ReLU bound constants
 RELU_POT_BOUND = 8.0

-# Supported TP models names for Tensorflow:
-DEFAULT_TP_MODEL = 'default'
-IMX500_TP_MODEL = 'imx500'
-TFLITE_TP_MODEL = 'tflite'
-QNNPACK_TP_MODEL = 'qnnpack'
-
-
 # TFOpLambda functions:
 ADD = 'add'
 PAD = 'pad'
@@ -25,9 +25,9 @@ else:
    from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU

 from model_compression_toolkit.core.common.defaultdict import DefaultDict
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
-from model_compression_toolkit.core.common.constants import SOFTMAX_THRESHOLD
+from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
 from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
     KERNEL, DEPTHWISE_KERNEL
 from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
@@ -17,7 +17,7 @@
 from tensorflow.keras.layers import Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, Activation, SeparableConv2D

 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common.constants import FLOAT_32, DATA_TYPE
+from model_compression_toolkit.constants import FLOAT_32, DATA_TYPE
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, \
     NodeFrameworkAttrMatcher
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, WalkMatcher
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
-from model_compression_toolkit.core.common.constants import THRESHOLD
+from model_compression_toolkit.constants import THRESHOLD
 from model_compression_toolkit.core.keras.constants import KERNEL

 input_node = NodeOperationMatcher(InputLayer)
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
 from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
 from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \
     ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger


 def linear_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
@@ -23,17 +23,16 @@ else:
    from keras.layers.core import TFOpLambda
    from keras.layers import MultiHeadAttention, Conv2D, Softmax, Concatenate, Reshape, Permute

-from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
-from model_compression_toolkit.core.common.constants import REUSE, REUSE_GROUP
-from model_compression_toolkit.core.keras.reader.node_builder import REUSED_IDENTIFIER
+from model_compression_toolkit.constants import REUSE, REUSE_GROUP
 from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, NUM_HEADS, KEY_DIM, VALUE_DIM, \
     QUERY_SHAPE, KEY_SHAPE, VALUE_SHAPE, OUTPUT_SHAPE, ATTENTION_AXES, ACTIVATION, LINEAR, FILTERS, \
     FUNCTION, DIMS, TARGET_SHAPE, F_STRIDED_SLICE, F_STACK, Q_KERNEL, Q_BIAS, K_KERNEL, K_BIAS, V_KERNEL, V_BIAS, \
-    OUTPUT_KERNEL, OUTPUT_BIAS, F_MATMUL, TRANSPOSE_B, KERNEL_SIZE, AXIS, F_STRIDED_SLICE_BEGIN, F_STRIDED_SLICE_END
+    OUTPUT_KERNEL, OUTPUT_BIAS, F_MATMUL, KERNEL_SIZE, AXIS, F_STRIDED_SLICE_BEGIN, F_STRIDED_SLICE_END


 class MHAParams:
@@ -23,6 +23,7 @@ from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
 from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, RELU_MAX_VALUE, RELU_POT_BOUND
+from model_compression_toolkit.logger import Logger


 class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
@@ -81,7 +82,7 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
            scale_factor = max_value / self.threshold

            non_linear_node.framework_attr[RELU_MAX_VALUE] = np.float32(self.threshold)
-            common.Logger.debug(
+            Logger.debug(
                f"Node named:{non_linear_node.name} max value change "
                f"to:{non_linear_node.framework_attr[RELU_MAX_VALUE]}")