mct-nightly 1.8.0.22032023.post333__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +4 -3
  2. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +294 -284
  3. model_compression_toolkit/__init__.py +9 -32
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/framework_info.py +1 -1
  13. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  14. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  15. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  18. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  19. model_compression_toolkit/core/common/memory_computation.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  28. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  29. model_compression_toolkit/core/common/model_collector.py +2 -2
  30. model_compression_toolkit/core/common/model_validation.py +1 -1
  31. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  32. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  33. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  35. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  36. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  37. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  50. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  52. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  54. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  55. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  56. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  57. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  58. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  60. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  63. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  65. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  66. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  67. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  69. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  73. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  74. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  75. model_compression_toolkit/core/keras/constants.py +0 -7
  76. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  85. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  86. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  87. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  88. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  89. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  90. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  91. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  92. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  93. model_compression_toolkit/core/keras/reader/common.py +1 -1
  94. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  95. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  99. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  100. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/constants.py +0 -6
  102. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  103. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  109. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  110. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  111. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  112. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  113. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  114. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  115. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  116. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  117. model_compression_toolkit/core/runner.py +7 -7
  118. model_compression_toolkit/exporter/__init__.py +6 -3
  119. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  120. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  121. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  124. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  125. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  126. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  127. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  128. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  129. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  130. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
  131. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  132. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  133. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  135. model_compression_toolkit/gptq/common/gptq_config.py +2 -4
  136. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  137. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  138. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  139. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  140. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  141. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  142. model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
  143. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  144. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  145. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  146. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
  147. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +21 -16
  148. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  149. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
  150. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  151. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  152. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
  153. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  154. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  155. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  156. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
  157. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +13 -5
  158. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  159. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
  160. model_compression_toolkit/gptq/runner.py +3 -2
  161. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
  162. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  163. model_compression_toolkit/ptq/__init__.py +3 -0
  164. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  165. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  166. model_compression_toolkit/qat/__init__.py +4 -0
  167. model_compression_toolkit/qat/common/__init__.py +1 -2
  168. model_compression_toolkit/qat/common/qat_config.py +5 -1
  169. model_compression_toolkit/qat/keras/quantization_facade.py +34 -28
  170. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  171. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  172. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
  173. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
  174. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  175. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  176. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
  177. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
  178. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  179. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +3 -5
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -3
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  211. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  212. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  213. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
  214. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  215. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  217. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  218. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  219. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  220. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  221. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  222. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  223. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  224. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  225. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  226. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  227. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  228. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  229. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  233. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  234. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  235. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  236. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  237. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  238. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  239. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  240. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  241. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  242. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  243. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  244. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  248. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  250. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  254. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  255. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  257. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  259. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  261. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  263. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  265. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  273. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  274. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  275. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  276. model_compression_toolkit/exporter/model_exporter/tflite/__init__.py +0 -14
  277. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  278. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  279. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  280. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  281. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  282. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  283. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  284. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  285. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  286. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  287. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  288. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  289. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  290. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  291. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  292. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  293. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  294. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -16,7 +16,7 @@
16
16
  from collections.abc import Callable
17
17
  from functools import partial
18
18
 
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
20
  from model_compression_toolkit.core.common.quantization.quantizers.kmeans_quantizer import kmeans_quantizer
21
21
  from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
22
22
  from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
@@ -16,8 +16,8 @@
16
16
  from collections.abc import Callable
17
17
  from functools import partial
18
18
 
19
- from model_compression_toolkit.core.common.logger import Logger
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.logger import Logger
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
21
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.kmeans_params import kmeans_tensor
22
22
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
23
23
  lut_kmeans_tensor, lut_kmeans_histogram
@@ -17,10 +17,11 @@ from typing import Tuple, Callable
17
17
  import numpy as np
18
18
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
19
19
  from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_mae, compute_lp_norm
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
- from model_compression_toolkit.core.common.constants import FLOAT_32
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.constants import FLOAT_32
22
22
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor
23
23
 
24
+
24
25
  def _mse_error_histogram(q_bins: np.ndarray,
25
26
  q_count: np.ndarray,
26
27
  bins: np.ndarray,
@@ -17,7 +17,7 @@ import numpy as np
17
17
  from sklearn.cluster import KMeans
18
18
 
19
19
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
20
- from model_compression_toolkit.core.common.constants import CLUSTER_CENTERS, SCALE_PER_CHANNEL, MIN_THRESHOLD, EPS
20
+ from model_compression_toolkit.constants import CLUSTER_CENTERS, SCALE_PER_CHANNEL, MIN_THRESHOLD, EPS
21
21
 
22
22
 
23
23
  def kmeans_tensor(tensor_data: np.ndarray,
@@ -17,7 +17,7 @@ import numpy as np
17
17
  from sklearn.cluster import KMeans
18
18
 
19
19
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
20
- from model_compression_toolkit.core.common.constants import CLUSTER_CENTERS, MIN_THRESHOLD, SCALE_PER_CHANNEL, \
20
+ from model_compression_toolkit.constants import CLUSTER_CENTERS, MIN_THRESHOLD, SCALE_PER_CHANNEL, \
21
21
  MULTIPLIER_N_BITS, THRESHOLD
22
22
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
23
23
  max_power_of_two, int_quantization_with_threshold
@@ -26,7 +26,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
26
26
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import \
27
27
  power_of_two_selection_tensor
28
28
 
29
- from model_compression_toolkit.core.common.logger import Logger
29
+ from model_compression_toolkit.logger import Logger
30
30
 
31
31
 
32
32
  def lut_kmeans_tensor(tensor_data: np.ndarray,
@@ -15,13 +15,13 @@
15
15
  import numpy as np
16
16
 
17
17
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
18
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, THRESHOLD
18
+ from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD
19
19
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
20
20
  qparams_selection_tensor_search, qparams_selection_histogram_search
21
21
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, get_tensor_max
22
22
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
23
23
  get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
24
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
25
25
 
26
26
 
27
27
  def power_of_two_selection_tensor(tensor_data: np.ndarray,
@@ -13,11 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import numpy as np
16
- from typing import Tuple, Dict
16
+ from typing import Dict
17
17
 
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
20
- from model_compression_toolkit.core.common.constants import SIGNED
20
+ from model_compression_toolkit.constants import SIGNED
21
21
  from model_compression_toolkit.core.common.quantization import quantization_params_generation
22
22
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
23
23
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
@@ -16,7 +16,7 @@ from typing import List
16
16
 
17
17
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
18
18
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
19
- from model_compression_toolkit.core.common import Graph, BaseNode, Logger
19
+ from model_compression_toolkit.core.common import Graph, BaseNode
20
20
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
21
21
  import get_activations_qparams
22
22
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
@@ -18,7 +18,7 @@ from typing import Any, Tuple, Dict
18
18
 
19
19
  import numpy as np
20
20
 
21
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFAULT_DEC_FACTOR, \
21
+ from model_compression_toolkit.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFAULT_DEC_FACTOR, \
22
22
  SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS, SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER, SYMMETRIC_TENSOR_DEC_FREQ, \
23
23
  SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ, SYMMETRIC_TENSOR_N_INTERVALS, SYMMETRIC_TENSOR_N_ITER, \
24
24
  UNIFORM_TENSOR_PER_CHANNEL_N_ITER, UNIFORM_TENSOR_N_ITER, SYMMETRIC_HISTOGRAM_DEC_FREQ, SYMMETRIC_HISTOGRAM_N_ITER, \
@@ -16,7 +16,7 @@ from typing import Dict, Any, Tuple
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
21
21
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
22
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
@@ -15,7 +15,7 @@
15
15
  import numpy as np
16
16
 
17
17
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
18
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, THRESHOLD
18
+ from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD
19
19
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
20
20
  get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function, _kl_error_histogram
21
21
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
23
23
  qparams_symmetric_selection_histogram_search, kl_qparams_symmetric_selection_histogram_search
24
24
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
25
25
  get_tensor_max
26
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
26
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
27
27
 
28
28
 
29
29
  def symmetric_selection_tensor(tensor_data: np.ndarray,
@@ -15,14 +15,14 @@
15
15
  import numpy as np
16
16
 
17
17
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
18
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX
18
+ from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX
19
19
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
20
20
  qparams_uniform_selection_tensor_search, qparams_uniform_selection_histogram_search
21
21
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
22
22
  get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
23
23
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import get_tensor_max, \
24
24
  get_tensor_min
25
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
26
26
 
27
27
  def uniform_selection_tensor(tensor_data: np.ndarray,
28
28
  p: int,
@@ -20,9 +20,10 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
20
20
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
21
21
  from model_compression_toolkit.core.common.graph.base_graph import Graph
22
22
  from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
23
+ from model_compression_toolkit.logger import Logger
23
24
 
24
25
 
25
- def quantize_graph_weights(graph_to_quantize: Graph,
26
+ def quantize_graph_weights(graph: Graph,
26
27
  fw_info: FrameworkInfo,
27
28
  fw_impl: FrameworkImplementation) -> Graph:
28
29
  """
@@ -32,12 +33,11 @@ def quantize_graph_weights(graph_to_quantize: Graph,
32
33
  is calculated and subtracted from the original node's bias. The graph is quantized in-place.
33
34
 
34
35
  Args:
35
- graph_to_quantize: Graph to quantize its nodes.
36
+ graph: Graph to quantize its nodes.
36
37
  fw_info: Framework information needed for quantizing the graph's nodes' weights and activations.
37
38
  fw_impl: FrameworkImplementation with specific framework implementations.
38
39
 
39
40
  """
40
- graph = copy.deepcopy(graph_to_quantize)
41
41
  # Iterate over nodes in the graph and quantize each node's weights and activations
42
42
  # (according to operators groups in framework info).
43
43
  for n in graph.nodes():
@@ -48,7 +48,7 @@ def quantize_graph_weights(graph_to_quantize: Graph,
48
48
  n.final_weights_quantization_cfg,
49
49
  fw_impl=fw_impl)
50
50
 
51
- common.Logger.debug(
51
+ Logger.debug(
52
52
  f'Node name: {n.name} has the following quantization params: '
53
53
  f'{str(n.final_weights_quantization_cfg.weights_quantization_params)}')
54
54
 
@@ -15,7 +15,7 @@
15
15
 
16
16
 
17
17
  from model_compression_toolkit.core import common
18
- from model_compression_toolkit.core.common import Logger
18
+ from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
20
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
21
21
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
@@ -46,7 +46,7 @@ def get_quantized_kernel_by_weights_qc(fw_info: FrameworkInfo,
46
46
  # If weights should be quantized per-channel but a kernel channels mapping is missing.
47
47
  if weights_qc.weights_per_channel_threshold and fw_info.kernel_channels_mapping is \
48
48
  None:
49
- common.Logger.warning(
49
+ Logger.warning(
50
50
  'Weights Per Channel Quantization requires channel mapping function but framework info '
51
51
  'does not contain one')
52
52
  output_channels_axis, input_channels_axis = get_channels_axis(weights_qc,
@@ -16,7 +16,7 @@
16
16
  from sklearn.cluster import KMeans
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.constants import CLUSTER_CENTERS, MIN_THRESHOLD, SCALE_PER_CHANNEL
19
+ from model_compression_toolkit.constants import CLUSTER_CENTERS, MIN_THRESHOLD, SCALE_PER_CHANNEL
20
20
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters
21
21
 
22
22
 
@@ -15,7 +15,7 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.constants import CLUSTER_CENTERS, SCALE_PER_CHANNEL, \
18
+ from model_compression_toolkit.constants import CLUSTER_CENTERS, SCALE_PER_CHANNEL, \
19
19
  MULTIPLIER_N_BITS
20
20
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import kmeans_assign_clusters, \
21
21
  get_quantized_tensor, int_quantization_with_threshold
@@ -17,8 +17,10 @@
17
17
  from typing import Tuple, List
18
18
  import numpy as np
19
19
 
20
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, EPS
20
+ from model_compression_toolkit.constants import MIN_THRESHOLD, EPS
21
21
  from model_compression_toolkit.core import common
22
+ from model_compression_toolkit.logger import Logger
23
+
22
24
 
23
25
  def max_power_of_two(x: np.ndarray,
24
26
  min_threshold: float = MIN_THRESHOLD) -> np.ndarray:
@@ -236,7 +238,7 @@ def get_tensor_max(tensor_data: np.ndarray,
236
238
 
237
239
  """
238
240
  if n_bits < 1:
239
- common.Logger.error("n_bits must be positive")
241
+ Logger.error("n_bits must be positive")
240
242
  if is_uniform_quantization:
241
243
  expansion_factor = 1.0
242
244
  elif n_bits == 1:
@@ -15,8 +15,8 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.logger import Logger
19
- from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX, THRESHOLD
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX, THRESHOLD
20
20
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
21
21
  quantize_tensor
22
22
 
@@ -17,7 +17,8 @@
17
17
  import copy
18
18
  from typing import List
19
19
 
20
- from model_compression_toolkit.core.common import Logger, BaseNode
20
+ from model_compression_toolkit.core.common import BaseNode
21
+ from model_compression_toolkit.logger import Logger
21
22
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
23
  from model_compression_toolkit.core.common.graph.base_graph import Graph
23
24
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
@@ -28,8 +29,8 @@ from model_compression_toolkit.core.common.quantization.quantization_params_fn_s
28
29
  get_activation_quantization_params_fn, get_weights_quantization_params_fn
29
30
  from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
30
31
  get_weights_quantization_fn
31
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
32
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import OpQuantizationConfig, \
32
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
33
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
33
34
  QuantizationConfigOptions
34
35
 
35
36
 
@@ -48,14 +49,13 @@ def set_quantization_configuration_to_graph(graph: Graph,
48
49
  The graph with quantization configurations attached to each node in it.
49
50
  """
50
51
 
51
- graph_with_qcs = copy.deepcopy(graph)
52
- for n in graph_with_qcs.nodes:
52
+ for n in graph.nodes:
53
53
  set_quantization_configs_to_node(node=n,
54
54
  quant_config=quant_config,
55
55
  fw_info=graph.fw_info,
56
56
  tpc=graph.tpc,
57
57
  mixed_precision_enable=mixed_precision_enable)
58
- return graph_with_qcs
58
+ return graph
59
59
 
60
60
 
61
61
  def set_quantization_configs_to_node(node: BaseNode,
@@ -73,7 +73,7 @@ def set_quantization_configs_to_node(node: BaseNode,
73
73
  tpc: TargetPlatformCapabilities to get default OpQuantizationConfig.
74
74
  mixed_precision_enable: is mixed precision enabled
75
75
  """
76
- node_qc_options = tpc.get_qco_by_node(node)
76
+ node_qc_options = node.get_qco(tpc)
77
77
 
78
78
  # Create QC candidates for weights and activation combined
79
79
  weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)[0]
@@ -13,11 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Any, Tuple
16
+ from typing import Any
17
17
 
18
18
  import numpy as np
19
19
 
20
- from model_compression_toolkit.core.common.constants import EPS
20
+ from model_compression_toolkit.constants import EPS
21
21
 
22
22
  #########################
23
23
  # Helpful functions
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  import copy
16
16
 
17
- from model_compression_toolkit import CoreConfig
17
+ from model_compression_toolkit.core import CoreConfig
18
18
  from model_compression_toolkit.core.common import Graph, BaseNode
19
19
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
20
 
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import copy
16
15
  from typing import Callable, Any
17
16
 
18
17
  from tqdm import tqdm
@@ -92,7 +91,7 @@ def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
92
91
  return quantized_model
93
92
 
94
93
 
95
- def apply_second_moment_correction_to_graph(graph_to_apply_second_moment_correction: Graph,
94
+ def apply_second_moment_correction_to_graph(graph: Graph,
96
95
  representative_data_gen: Callable,
97
96
  core_config: CoreConfig,
98
97
  fw_info: FrameworkInfo,
@@ -100,7 +99,7 @@ def apply_second_moment_correction_to_graph(graph_to_apply_second_moment_correct
100
99
  """
101
100
  Apply second moment correction on graph.
102
101
  Args:
103
- graph_to_apply_second_moment_correction: Graph to apply second moment correction.
102
+ graph: Graph to apply second moment correction.
104
103
  representative_data_gen (Callable): Dataset used for calibration.
105
104
  core_config (CoreConfig): Configuration object containing parameters of how the model should be
106
105
  quantized, including mixed precision parameters.
@@ -110,7 +109,6 @@ def apply_second_moment_correction_to_graph(graph_to_apply_second_moment_correct
110
109
  Returns:
111
110
  Graph after second moment correction.
112
111
  """
113
- graph = copy.deepcopy(graph_to_apply_second_moment_correction)
114
112
  semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_info, fw_impl)
115
113
  fw_impl.apply_second_moment_correction(semi_quantized_model, core_config, representative_data_gen, graph)
116
114
  graph = substitute(graph, fw_impl.get_substitutions_after_second_moment_correction(core_config.quantization_config))
@@ -18,15 +18,16 @@ from typing import Any
18
18
 
19
19
  import numpy as np
20
20
 
21
- from model_compression_toolkit import CoreConfig
21
+ from model_compression_toolkit.core import CoreConfig
22
22
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
23
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
24
- from model_compression_toolkit.core.common import BaseNode, Logger, Graph
24
+ from model_compression_toolkit.core.common import BaseNode, Graph
25
25
  from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
26
26
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
27
+ from model_compression_toolkit.logger import Logger
27
28
 
28
29
 
29
- def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
30
+ def compute_bias_correction_of_graph(graph: Graph,
30
31
  core_config: CoreConfig,
31
32
  fw_info: FrameworkInfo,
32
33
  fw_impl: FrameworkImplementation) -> Graph:
@@ -35,7 +36,7 @@ def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
35
36
  compute the bias-correction term, and store it in the candidate weights quantization configuration.
36
37
 
37
38
  Args:
38
- graph_co_compute_bias: Graph with nodes to compute the bias correction for
39
+ graph: Graph with nodes to compute the bias correction for
39
40
  each node's weights quantization configuration candidates.
40
41
  core_config: CoreConfig containing parameters of how the model should be quantized.
41
42
  fw_info: Framework info like lists of nodes their kernel should quantized.
@@ -46,7 +47,6 @@ def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
46
47
  for each node.
47
48
  """
48
49
 
49
- graph = copy.deepcopy(graph_co_compute_bias)
50
50
  for n in graph.nodes:
51
51
  if n.is_weights_quantization_enabled() and core_config.quantization_config.weights_bias_correction:
52
52
  _compute_bias_correction_per_candidate_qc(n,
@@ -13,14 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import copy
17
-
18
16
  from typing import List
19
17
 
20
18
  from model_compression_toolkit.core import common
21
19
 
22
20
 
23
- def substitute(graph_to_substitute: common.Graph,
21
+ def substitute(graph: common.Graph,
24
22
  substitutions_list: List[common.BaseSubstitution]) -> common.Graph:
25
23
  """
26
24
  Apply a list of substitutions on a graph.
@@ -32,9 +30,8 @@ def substitute(graph_to_substitute: common.Graph,
32
30
  Transformed graph after applying all substitutions in substitutions_list.
33
31
  """
34
32
 
35
- graph = copy.deepcopy(graph_to_substitute)
36
33
  for substitution in substitutions_list:
37
34
  matched_nodes = graph.filter(substitution.matcher_instance)
38
35
  for idn in matched_nodes:
39
36
  graph = substitution.substitute(graph, idn)
40
- return graph
37
+ return graph
@@ -20,11 +20,11 @@ from typing import Callable
20
20
  import numpy as np
21
21
 
22
22
  from model_compression_toolkit.core import common
23
- from model_compression_toolkit.core.common import Logger
23
+ from model_compression_toolkit.logger import Logger
24
24
  from model_compression_toolkit.core.common.graph.base_graph import Graph
25
25
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
26
26
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
27
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
27
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
28
28
 
29
29
 
30
30
  class BatchNormalizationReconstruction(common.BaseSubstitution):
@@ -22,9 +22,9 @@ from model_compression_toolkit.core import common
22
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
23
23
  from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher
24
24
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
25
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
26
- from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
27
- from model_compression_toolkit.core.common.logger import Logger
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
26
+ from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
27
+ from model_compression_toolkit.logger import Logger
28
28
 
29
29
 
30
30
  class BatchNormalizationRefusing(common.BaseSubstitution):
@@ -17,7 +17,7 @@
17
17
  import copy
18
18
  import numpy as np
19
19
  from typing import Tuple, Callable
20
- from model_compression_toolkit.core.common.logger import Logger
20
+ from model_compression_toolkit.logger import Logger
21
21
  from model_compression_toolkit.core import common
22
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
23
23
  from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import copy
17
-
18
16
  from model_compression_toolkit.core import common
19
17
 
20
18
 
@@ -32,7 +30,6 @@ def linear_collapsing_substitute(graph: common.Graph,
32
30
  Returns:
33
31
  Transformed graph after applying all linear collapsing substitutions.
34
32
  """
35
- graph = copy.deepcopy(graph)
36
33
  matched_nodes = graph.filter(linear_collapsing_substitution.matcher_instance)
37
34
  matched_nodes_list = []
38
35
  match_indicator = True
@@ -16,11 +16,11 @@ import copy
16
16
  import numpy as np
17
17
  from typing import List, Tuple, Any, Callable
18
18
 
19
- from model_compression_toolkit.core.common.logger import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
21
- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
21
+ from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
22
22
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
23
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
23
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
24
24
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
25
25
  set_quantization_configs_to_node
26
26
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -356,7 +356,7 @@ def shift_negative_function(graph: Graph,
356
356
  bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
357
357
  graph.shift_stats_collector(bypass_node, np.array(shift_value))
358
358
 
359
- add_node_qco = graph.tpc.get_qco_by_node(add_node).quantization_config_list
359
+ add_node_qco = add_node.get_qco(graph.tpc).quantization_config_list
360
360
  for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
361
361
  candidate_qc.weights_quantization_cfg.enable_weights_quantization = False
362
362
 
@@ -495,7 +495,7 @@ def apply_shift_negative_correction(graph: Graph,
495
495
  nodes = list(graph.nodes())
496
496
  for n in nodes:
497
497
  # Skip substitution if QuantizationMethod is uniform.
498
- node_qco = graph.tpc.get_qco_by_node(n)
498
+ node_qco = n.get_qco(graph.tpc)
499
499
  if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
500
500
  for op_qc in node_qco.quantization_config_list]):
501
501
  continue
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
17
- from model_compression_toolkit.core.common.logger import Logger
17
+ from model_compression_toolkit.logger import Logger
18
18
  from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
19
19
 
20
20
 
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import itertools
17
- from model_compression_toolkit.core.common.logger import Logger
17
+ from model_compression_toolkit.logger import Logger
18
18
  from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
19
19
  from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualSplitWeightsNode, \
20
20
  VirtualSplitActivationNode
@@ -31,7 +31,7 @@ from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
31
31
  from tensorboard.summary.writer.event_file_writer import EventFileWriter
32
32
  from typing import List, Any, Dict
33
33
  from networkx import topological_sort
34
- from model_compression_toolkit import FrameworkInfo
34
+ from model_compression_toolkit.core import FrameworkInfo
35
35
  from model_compression_toolkit.core.common import Graph, BaseNode
36
36
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
37
37
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common import Logger
16
+ from model_compression_toolkit.logger import Logger
17
17
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
18
18
  from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
19
19
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import List
16
16
 
17
- from model_compression_toolkit import FrameworkInfo
17
+ from model_compression_toolkit.core import FrameworkInfo
18
18
  from model_compression_toolkit.core.common import BaseNode
19
19
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
20
20
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO