mct-nightly 1.8.0.22032023.post333__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +4 -3
  2. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +294 -284
  3. model_compression_toolkit/__init__.py +9 -32
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/framework_info.py +1 -1
  13. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  14. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  15. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  18. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  19. model_compression_toolkit/core/common/memory_computation.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  28. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  29. model_compression_toolkit/core/common/model_collector.py +2 -2
  30. model_compression_toolkit/core/common/model_validation.py +1 -1
  31. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  32. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  33. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  35. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  36. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  37. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  50. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  52. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  54. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  55. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  56. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  57. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  58. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  60. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  63. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  65. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  66. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  67. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  69. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  73. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  74. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  75. model_compression_toolkit/core/keras/constants.py +0 -7
  76. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  85. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  86. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  87. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  88. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  89. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  90. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  91. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  92. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  93. model_compression_toolkit/core/keras/reader/common.py +1 -1
  94. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  95. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  99. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  100. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/constants.py +0 -6
  102. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  103. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  109. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  110. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  111. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  112. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  113. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  114. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  115. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  116. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  117. model_compression_toolkit/core/runner.py +7 -7
  118. model_compression_toolkit/exporter/__init__.py +6 -3
  119. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  120. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  121. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  124. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  125. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  126. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  127. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  128. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  129. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  130. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
  131. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  132. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  133. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  135. model_compression_toolkit/gptq/common/gptq_config.py +2 -4
  136. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  137. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  138. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  139. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  140. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  141. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  142. model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
  143. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  144. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  145. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  146. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
  147. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +21 -16
  148. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  149. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
  150. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  151. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  152. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
  153. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  154. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  155. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  156. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
  157. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +13 -5
  158. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  159. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
  160. model_compression_toolkit/gptq/runner.py +3 -2
  161. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
  162. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  163. model_compression_toolkit/ptq/__init__.py +3 -0
  164. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  165. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  166. model_compression_toolkit/qat/__init__.py +4 -0
  167. model_compression_toolkit/qat/common/__init__.py +1 -2
  168. model_compression_toolkit/qat/common/qat_config.py +5 -1
  169. model_compression_toolkit/qat/keras/quantization_facade.py +34 -28
  170. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  171. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  172. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
  173. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
  174. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  175. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  176. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
  177. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
  178. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  179. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +3 -5
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -3
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  211. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  212. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  213. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
  214. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  215. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  217. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  218. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  219. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  220. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  221. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  222. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  223. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  224. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  225. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  226. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  227. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  228. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  229. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  233. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  234. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  235. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  236. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  237. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  238. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  239. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  240. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  241. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  242. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  243. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  244. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  248. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  250. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  254. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  255. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  257. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  259. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  261. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  263. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  265. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  273. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  274. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  275. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  276. model_compression_toolkit/exporter/model_exporter/tflite/__init__.py +0 -14
  277. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  278. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  279. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  280. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  281. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  282. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  283. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  284. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  285. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  286. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  287. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  288. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  289. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  290. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  291. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  292. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  293. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  294. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -16,15 +16,14 @@
16
16
  from typing import Callable
17
17
  from functools import partial
18
18
 
19
- from model_compression_toolkit import CoreConfig
20
- from model_compression_toolkit.core import common
21
- from model_compression_toolkit.core.common import Logger
22
- from model_compression_toolkit.core.common.constants import TENSORFLOW, FOUND_TF
23
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
19
+ from model_compression_toolkit.core import CoreConfig
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.constants import FOUND_TF
24
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
25
23
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
26
24
  MixedPrecisionQuantizationConfigV2
27
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
25
+ from model_compression_toolkit.quantizers_infrastructure import ActivationQuantizationHolder
26
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
28
27
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
29
28
  from model_compression_toolkit.ptq.runner import ptq_runner
30
29
 
@@ -36,7 +35,7 @@ if FOUND_TF:
36
35
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
37
36
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
38
37
  from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
39
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
38
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
40
39
 
41
40
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
42
41
 
@@ -46,33 +45,36 @@ if FOUND_TF:
46
45
  from model_compression_toolkit import get_target_platform_capabilities
47
46
  from model_compression_toolkit.core import common
48
47
  from model_compression_toolkit.core.common import BaseNode
49
- from model_compression_toolkit.core.common.constants import TENSORFLOW
48
+ from model_compression_toolkit.constants import TENSORFLOW
50
49
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
51
50
  from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
52
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
51
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
53
52
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
54
- from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder
53
+ from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
54
+ get_activation_quantizer_holder
55
55
  from model_compression_toolkit.qat.common.qat_config import QATConfig
56
56
  from model_compression_toolkit import quantizers_infrastructure as qi
57
57
 
58
58
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
59
59
 
60
60
 
61
- def qat_wrapper(n: common.BaseNode, layer: Layer, qat_config):
61
+ def qat_wrapper(n: common.BaseNode,
62
+ layer: Layer,
63
+ qat_config: QATConfig):
62
64
  """
63
65
  A function which takes a computational graph node and a keras layer and perform the quantization wrapping
64
66
  Args:
67
+ qat_config: Configuration of QAT (such as training methods for example).
65
68
  n: A node of mct graph.
66
- layer: A keras layer
69
+ layer: A keras layer.
67
70
 
68
71
  Returns: Wrapped layer
69
72
 
70
73
  """
71
74
  if _is_qat_applicable(n, DEFAULT_KERAS_INFO):
72
75
  weights_quantizers, activation_quantizers = quantization_builder(n, qat_config, DEFAULT_KERAS_INFO)
73
- return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
74
- else:
75
- return layer
76
+ return qi.KerasQuantizationWrapper(layer, weights_quantizers)
77
+ return layer
76
78
 
77
79
 
78
80
  def keras_quantization_aware_training_init(in_model: Model,
@@ -134,24 +136,24 @@ if FOUND_TF:
134
136
 
135
137
  Create a MCT core config, containing the quantization configuration:
136
138
 
137
- >>> config = mct.CoreConfig()
139
+ >>> config = mct.core.CoreConfig()
138
140
 
139
141
  If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
140
142
  The candidates bitwidth for quantization should be defined in the target platform model:
141
143
 
142
- >>> config = mct.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
144
+ >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
143
145
 
144
146
  For mixed-precision set a target KPI object:
145
147
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
146
148
  that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
147
149
  while the bias will not):
148
150
 
149
- >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
151
+ >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
150
152
 
151
153
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
152
154
  quantized model:
153
155
 
154
- >>> quantized_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
156
+ >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
155
157
 
156
158
  Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:
157
159
 
@@ -165,11 +167,11 @@ if FOUND_TF:
165
167
 
166
168
  if core_config.mixed_precision_enable:
167
169
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
168
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
170
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
169
171
  "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization API,"
170
172
  "or pass a valid mixed precision configuration.")
171
173
 
172
- common.Logger.info("Using experimental mixed-precision quantization. "
174
+ Logger.info("Using experimental mixed-precision quantization. "
173
175
  "If you encounter an issue please file a bug.")
174
176
 
175
177
  tb_w = _init_tensorboard_writer(fw_info)
@@ -188,7 +190,11 @@ if FOUND_TF:
188
190
  tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
189
191
 
190
192
  _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
191
- qat_model, user_info = KerasModelBuilder(graph=tg, fw_info=fw_info, wrapper=_qat_wrapper).build_model()
193
+ qat_model, user_info = KerasModelBuilder(graph=tg,
194
+ fw_info=fw_info,
195
+ wrapper=_qat_wrapper,
196
+ get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
197
+ qat_config=qat_config)).build_model()
192
198
 
193
199
  user_info.mixed_precision_cfg = bit_widths_config
194
200
  #TODO: remove the last output after updating documentation.
@@ -223,33 +229,33 @@ if FOUND_TF:
223
229
 
224
230
  Create a MCT core config, containing the quantization configuration:
225
231
 
226
- >>> config = mct.CoreConfig()
232
+ >>> config = mct.core.CoreConfig()
227
233
 
228
234
  If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
229
235
  The candidates bitwidth for quantization should be defined in the target platform model:
230
236
 
231
- >>> config = mct.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
237
+ >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
232
238
 
233
239
  For mixed-precision set a target KPI object:
234
240
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
235
241
  that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
236
242
  while the bias will not):
237
243
 
238
- >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
244
+ >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
239
245
 
240
246
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
241
247
  quantized model:
242
248
 
243
- >>> quantized_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
249
+ >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
244
250
 
245
251
  Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:
246
252
 
247
253
  >>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)
248
- >>> quantized_model = mct.keras_quantization_aware_training_finalize(quantized_model)
254
+ >>> quantized_model = mct.qat.keras_quantization_aware_training_finalize(quantized_model)
249
255
 
250
256
  """
251
257
  def _export(layer):
252
- if isinstance(layer, qi.KerasQuantizationWrapper):
258
+ if isinstance(layer, (qi.KerasQuantizationWrapper, ActivationQuantizationHolder)):
253
259
  layer.convert_to_inferable_quantizers()
254
260
  return layer
255
261
 
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Union
16
16
 
17
- from model_compression_toolkit.core.common import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TF
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TF
19
19
 
20
20
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
21
21
  TrainableQuantizerActivationConfig, BaseKerasTrainableQuantizer
@@ -12,20 +12,47 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Tuple, Dict, List
15
+ from typing import Tuple, Dict, List, Union, Callable
16
16
 
17
17
  from model_compression_toolkit.core import common
18
18
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
19
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.qat.common.qat_config import QATConfig, _is_qat_applicable
22
+ from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
23
+ from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget, ActivationQuantizationHolder
19
24
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
20
25
  get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config, \
21
26
  get_trainable_quantizer_quantization_candidates
22
- from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
23
- from model_compression_toolkit.qat.common.qat_config import QATConfig
24
- from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
25
27
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
26
28
  get_trainable_quantizer_class
27
29
 
28
30
 
31
+ def get_activation_quantizer_holder(n: common.BaseNode,
32
+ qat_config: QATConfig) -> Union[None, Callable]:
33
+ """
34
+ Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
35
+ If the layer is not supposed to be wrapped with activation quantizers - return None.
36
+
37
+ Args:
38
+ n: Node to get ActivationQuantizationHolder to attach in its output.
39
+ qat_config: Configuration of QAT (such as training methods for example).
40
+
41
+ Returns:
42
+ A ActivationQuantizationHolder layer for the node activation quantization.
43
+ """
44
+ _, activation_quantizers = quantization_builder(n,
45
+ qat_config,
46
+ DEFAULT_KERAS_INFO)
47
+
48
+ # Holder by definition uses a single quantizer for the activation quantization
49
+ # thus we make sure this is the only possible case (unless it's a node with no activation
50
+ # quantization, which in this case has an empty list).
51
+ if len(activation_quantizers) == 1:
52
+ return ActivationQuantizationHolder(activation_quantizers[0])
53
+ Logger.error(f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers were found for node {n}')
54
+
55
+
29
56
  def quantization_builder(n: common.BaseNode,
30
57
  qat_config: QATConfig,
31
58
  fw_info: FrameworkInfo,
@@ -18,13 +18,15 @@ from typing import Union
18
18
  import numpy as np
19
19
  import tensorflow as tf
20
20
  from tensorflow.python.framework.tensor_shape import TensorShape
21
- from model_compression_toolkit.core.common.constants import SIGNED
21
+ from model_compression_toolkit.constants import SIGNED
22
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
22
23
 
23
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
24
+ from model_compression_toolkit.qat import TrainingMethod
25
+
26
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
24
27
  from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
25
- from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
26
- from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
27
- from model_compression_toolkit.core.common import constants as C
28
+ from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
29
+
28
30
  from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
29
31
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
30
32
  TrainableQuantizerActivationConfig
@@ -53,11 +55,11 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
53
55
  """
54
56
  super().__init__(quantization_config)
55
57
  self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
56
- self.threshold_values = quantization_config.weights_quantization_params[C.THRESHOLD]
58
+ self.threshold_values = np.array(quantization_config.weights_quantization_params[C.THRESHOLD])
57
59
  self.threshold_shape = np.asarray(self.threshold_values).shape
58
60
  self.per_channel = self.quantization_config.weights_per_channel_threshold
59
61
  self.channel_axis = self.quantization_config.weights_channels_axis
60
- self.np_threshold_values = np.reshape(np.asarray(self.threshold_values),[-1]) if self.channel_axis else float(self.threshold_values)
62
+ self.np_threshold_values = np.reshape(np.asarray(self.threshold_values),[-1]) if self.per_channel else float(self.threshold_values)
61
63
 
62
64
  if self.per_channel and self.channel_axis not in [-1, len(self.threshold_shape) - 1]:
63
65
  # Tensorflow's fake_quant_with_min_max_vars_per_channel only works on last axis, so
@@ -93,21 +95,21 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
93
95
  """
94
96
  ptq_threshold_tensor = layer.add_weight(
95
97
  name + THRESHOLD_TENSOR,
96
- shape=len(self.np_threshold_values) if self.channel_axis else (),
98
+ shape=len(self.np_threshold_values) if self.per_channel else (),
97
99
  initializer=tf.keras.initializers.Constant(1.0),
98
100
  trainable=False)
99
101
  ptq_threshold_tensor.assign(self.np_threshold_values)
100
102
 
101
103
  fq_min = layer.add_weight(
102
104
  name + FQ_MIN,
103
- shape=len(self.min) if self.channel_axis else (),
105
+ shape=len(self.min) if self.per_channel else (),
104
106
  initializer=tf.keras.initializers.Constant(-1.0),
105
107
  trainable=False)
106
108
  fq_min.assign(self.min)
107
109
 
108
110
  fq_max = layer.add_weight(
109
111
  name + FQ_MAX,
110
- shape=len(self.max) if self.channel_axis else (),
112
+ shape=len(self.max) if self.per_channel else (),
111
113
  initializer=tf.keras.initializers.Constant(1.0),
112
114
  trainable=False)
113
115
  fq_max.assign(self.max)
@@ -134,7 +136,7 @@ class STEWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
134
136
 
135
137
  _min = self.get_quantizer_variable(FQ_MIN)
136
138
  _max = self.get_quantizer_variable(FQ_MAX)
137
- if self.channel_axis:
139
+ if self.per_channel:
138
140
  if self.perm_vec:
139
141
  inputs = tf.transpose(inputs, perm=self.perm_vec)
140
142
  q_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(inputs, _min, _max,
@@ -15,13 +15,15 @@
15
15
  import numpy as np
16
16
  import tensorflow as tf
17
17
  from tensorflow.python.framework.tensor_shape import TensorShape
18
- from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
- from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
18
+ from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
19
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
20
+ from model_compression_toolkit.qat import TrainingMethod
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
+
21
23
  from model_compression_toolkit.qat.keras.quantizer.quant_utils import adjust_range_to_include_zero
22
24
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
23
- from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
24
- from model_compression_toolkit.core.common import constants as C
25
+ from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
26
+
25
27
  from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
26
28
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
27
29
  TrainableQuantizerActivationConfig
@@ -50,8 +52,8 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
50
52
 
51
53
  """
52
54
  super().__init__(quantization_config)
53
- self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
54
- self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
55
+ self.max_values = np.array(quantization_config.weights_quantization_params[RANGE_MAX])
56
+ self.min_values = np.array(quantization_config.weights_quantization_params[RANGE_MIN])
55
57
  self.num_bits = self.quantization_config.weights_n_bits
56
58
  self.per_channel = self.quantization_config.weights_per_channel_threshold
57
59
  self.channel_axis = self.quantization_config.weights_channels_axis
@@ -98,7 +100,6 @@ class STEUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
98
100
  self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
99
101
  self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
100
102
 
101
-
102
103
  def __call__(self, inputs: tf.Tensor,
103
104
  training: bool):
104
105
  """
@@ -199,7 +200,6 @@ class STEUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
199
200
  self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
200
201
  self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
201
202
 
202
-
203
203
  def __call__(self,
204
204
  inputs: tf.Tensor,
205
205
  training: bool):
@@ -16,16 +16,16 @@ import copy
16
16
  from typing import Callable
17
17
  from functools import partial
18
18
 
19
- from model_compression_toolkit.core.common.constants import FOUND_TORCH, PYTORCH
19
+ from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
20
20
 
21
- from model_compression_toolkit import CoreConfig
21
+ from model_compression_toolkit.core import CoreConfig
22
22
  from model_compression_toolkit.core import common
23
- from model_compression_toolkit.core.common import Logger
23
+ from model_compression_toolkit.logger import Logger
24
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
25
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
26
26
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
27
27
  MixedPrecisionQuantizationConfigV2
28
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
28
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
29
29
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
30
30
  from model_compression_toolkit.ptq.runner import ptq_runner
31
31
 
@@ -34,7 +34,7 @@ if FOUND_TORCH:
34
34
  import torch.nn as nn
35
35
  from torch.nn import Module
36
36
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
37
- from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
37
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
38
38
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
39
39
  from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
40
40
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
@@ -121,7 +121,7 @@ if FOUND_TORCH:
121
121
 
122
122
  Create a MCT core config, containing the quantization configuration:
123
123
 
124
- >>> config = mct.CoreConfig()
124
+ >>> config = mct.core.CoreConfig()
125
125
 
126
126
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
127
127
  quantized model. Now the model contains quantizer wrappers for fine tunning the weights:
@@ -134,11 +134,11 @@ if FOUND_TORCH:
134
134
 
135
135
  if core_config.mixed_precision_enable:
136
136
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
137
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
137
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
138
138
  "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API,"
139
139
  "or pass a valid mixed precision configuration.")
140
140
 
141
- common.Logger.info("Using experimental mixed-precision quantization. "
141
+ Logger.info("Using experimental mixed-precision quantization. "
142
142
  "If you encounter an issue please file a bug.")
143
143
 
144
144
  tb_w = _init_tensorboard_writer(fw_info)
@@ -193,7 +193,7 @@ if FOUND_TORCH:
193
193
 
194
194
  Create a MCT core config, containing the quantization configuration:
195
195
 
196
- >>> config = mct.CoreConfig()
196
+ >>> config = mct.core.CoreConfig()
197
197
 
198
198
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
199
199
  quantized model:
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Union
16
16
 
17
- from model_compression_toolkit.core.common.logger import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
19
 
20
20
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
21
21
  TrainableQuantizerActivationConfig
@@ -18,12 +18,13 @@ import numpy as np
18
18
  import torch
19
19
  import torch.nn as nn
20
20
 
21
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.qat import TrainingMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
23
  from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
23
- from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
24
+ from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
24
25
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
25
26
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
26
- from model_compression_toolkit.core.common import constants as C
27
+
27
28
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
28
29
  from model_compression_toolkit.qat.pytorch.quantizer.quantizer_utils import ste_round, ste_clip, symmetric_quantizer
29
30
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
@@ -17,11 +17,13 @@ import torch
17
17
  import torch.nn as nn
18
18
  from torch import Tensor
19
19
 
20
- from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
21
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
- from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
23
- from model_compression_toolkit.core.common import constants as C
24
- from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
20
+ from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
21
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
22
+
23
+ from model_compression_toolkit.qat import TrainingMethod
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
25
+ from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
26
+
25
27
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
26
28
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
27
29
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget, BaseInferableQuantizer
17
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
18
- TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
17
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
19
18
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.base_keras_quantizer import BaseKerasTrainableQuantizer
20
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
21
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantize_wrapper import KerasQuantizationWrapper
22
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantize_wrapper import PytorchQuantizationWrapper
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder
23
23
 
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ # Quantizers constants (for GPTQ, QAT, etc.)
16
17
  FQ_MIN = "min"
17
18
  FQ_MAX = "max"
18
19
  THRESHOLD_TENSOR = "ptq_threshold_tensor"
@@ -15,7 +15,7 @@
15
15
  from enum import Enum
16
16
  from typing import Any, Dict, List
17
17
 
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
 
20
20
 
21
21
  class QuantizationTarget(Enum):
@@ -15,7 +15,12 @@
15
15
 
16
16
  IS_WEIGHTS = "is_weights"
17
17
  IS_ACTIVATIONS = "is_activations"
18
+
19
+ # In KerasQuantizationWrapper and PytorchQuantizationWrapper multiple quantizers are kept
18
20
  ACTIVATION_QUANTIZERS = "activation_quantizers"
21
+ # In ActivationQuantizationHolder only one quantizer is used thus a new attribute name is needed
22
+ ACTIVATION_HOLDER_QUANTIZER = "activation_holder_quantizer"
23
+
19
24
  WEIGHTS_QUANTIZERS = "weights_quantizer"
20
25
  WEIGHTS_QUANTIZATION_METHOD = 'weights_quantization_method'
21
26
  WEIGHTS_N_BITS = 'weights_n_bits'
@@ -13,8 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common import Logger
17
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
16
+ from model_compression_toolkit.logger import Logger
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
18
18
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import QUANTIZATION_TARGET, \
20
20
  QUANTIZATION_METHOD