mct-nightly 1.8.0.8042023.post345__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +4 -3
  2. {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +285 -277
  3. model_compression_toolkit/__init__.py +9 -32
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/framework_info.py +1 -1
  13. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  14. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  15. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  18. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  19. model_compression_toolkit/core/common/memory_computation.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  28. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  29. model_compression_toolkit/core/common/model_collector.py +2 -2
  30. model_compression_toolkit/core/common/model_validation.py +1 -1
  31. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  32. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  33. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  35. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  36. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  37. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  50. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  52. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  54. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  55. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  56. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  57. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  58. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  60. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  63. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  65. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  66. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  67. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  69. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  73. model_compression_toolkit/core/keras/back2framework/model_gradients.py +2 -2
  74. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  75. model_compression_toolkit/core/keras/constants.py +0 -7
  76. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  85. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  86. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  87. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  88. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  89. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  90. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  91. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  92. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  93. model_compression_toolkit/core/keras/reader/common.py +1 -1
  94. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  95. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +2 -2
  99. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  100. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/constants.py +0 -6
  102. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  103. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  109. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  110. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  111. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  112. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  113. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  114. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  115. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  116. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  117. model_compression_toolkit/core/runner.py +7 -7
  118. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  119. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  120. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  121. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  122. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
  125. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +3 -2
  128. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
  129. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  131. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
  132. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  133. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  134. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  135. model_compression_toolkit/gptq/common/gptq_training.py +2 -1
  136. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  137. model_compression_toolkit/gptq/keras/gptq_training.py +5 -4
  138. model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
  139. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  140. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  141. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  142. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
  143. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
  144. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  145. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
  146. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  147. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  148. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
  149. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  150. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +8 -3
  151. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
  152. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
  153. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +9 -11
  154. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
  155. model_compression_toolkit/gptq/runner.py +3 -2
  156. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
  157. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  158. model_compression_toolkit/ptq/__init__.py +3 -0
  159. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  160. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  161. model_compression_toolkit/qat/__init__.py +4 -0
  162. model_compression_toolkit/qat/common/__init__.py +1 -2
  163. model_compression_toolkit/qat/common/qat_config.py +3 -1
  164. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  165. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  166. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
  167. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
  168. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  169. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  170. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
  171. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
  172. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  178. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  179. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +2 -2
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  202. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  203. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  204. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
  205. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  207. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  208. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  209. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  210. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  211. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  212. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  213. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  214. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  215. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  216. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  217. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  218. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  219. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  220. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  221. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  222. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  223. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  224. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  225. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  226. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  227. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  228. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  229. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  230. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  231. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  232. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  233. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  234. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  235. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  236. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  237. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  238. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  239. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  240. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  241. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  242. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  243. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  244. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  245. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  248. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  254. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  255. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  259. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  261. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  264. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  265. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  266. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  267. {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  268. {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +0 -0
  269. {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  270. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  271. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  272. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  273. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  274. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  275. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  276. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  277. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  278. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  279. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  280. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  281. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  282. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  283. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  284. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  285. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  286. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  287. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  288. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -18,7 +18,7 @@ from typing import Union, List, Any
18
18
  from inspect import signature
19
19
 
20
20
  from model_compression_toolkit.core import common
21
- from model_compression_toolkit.core.common import Logger
21
+ from model_compression_toolkit.logger import Logger
22
22
 
23
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer, \
24
24
  QuantizationTarget
@@ -55,9 +55,9 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
55
55
  for i, (k, v) in enumerate(self.get_sig().parameters.items()):
56
56
  if i == 0:
57
57
  if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
58
- common.Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
58
+ Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
59
59
  elif v.default is v.empty:
60
- common.Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover
60
+ Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover
61
61
 
62
62
  super(BaseTrainableQuantizer, self).__init__()
63
63
  self.quantization_config = quantization_config
@@ -73,15 +73,15 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
73
73
  if static_quantization_target == QuantizationTarget.Weights:
74
74
  self.validate_weights()
75
75
  if self.quantization_config.weights_quantization_method not in static_quantization_method:
76
- common.Logger.error(
76
+ Logger.error(
77
77
  f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.weights_quantization_method}')
78
78
  elif static_quantization_target == QuantizationTarget.Activation:
79
79
  self.validate_activation()
80
80
  if self.quantization_config.activation_quantization_method not in static_quantization_method:
81
- common.Logger.error(
81
+ Logger.error(
82
82
  f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.activation_quantization_method}')
83
83
  else:
84
- common.Logger.error(
84
+ Logger.error(
85
85
  f'Unknown Quantization Part:{static_quantization_target}') # pragma: no cover
86
86
 
87
87
  self.quantizer_parameters = {}
@@ -145,7 +145,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
145
145
 
146
146
  """
147
147
  if self.activation_quantization() or not self.weights_quantization():
148
- common.Logger.error(f'Expect weight quantization got activation')
148
+ Logger.error(f'Expect weight quantization got activation')
149
149
 
150
150
  def validate_activation(self) -> None:
151
151
  """
@@ -153,7 +153,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
153
153
 
154
154
  """
155
155
  if not self.activation_quantization() or self.weights_quantization():
156
- common.Logger.error(f'Expect activation quantization got weight')
156
+ Logger.error(f'Expect activation quantization got weight')
157
157
 
158
158
  def convert2inferable(self) -> BaseInferableQuantizer:
159
159
  """
@@ -183,7 +183,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
183
183
  if name in self.quantizer_parameters:
184
184
  return self.quantizer_parameters[name][VAR]
185
185
  else:
186
- common.Logger.error(f'Variable {name} is not exist in quantizers parameters!') # pragma: no cover
186
+ Logger.error(f'Variable {name} is not exist in quantizers parameters!') # pragma: no cover
187
187
 
188
188
 
189
189
  @abstractmethod
@@ -13,7 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from typing import List
16
- from model_compression_toolkit.core.common import BaseNode, Logger
16
+ from model_compression_toolkit.core.common import BaseNode
17
+ from model_compression_toolkit.logger import Logger
17
18
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
18
19
  TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig, TrainableQuantizerCandidateConfig
19
20
 
@@ -12,12 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Union
15
+ from typing import Union, Any
16
16
 
17
- from model_compression_toolkit.gptq import RoundingType
18
- from model_compression_toolkit import TrainingMethod
19
- from model_compression_toolkit.core.common import Logger
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
19
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
22
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
23
21
  import QUANTIZATION_TARGET, QUANTIZATION_METHOD, QUANTIZER_TYPE
@@ -26,7 +24,7 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
26
24
 
27
25
 
28
26
  def get_trainable_quantizer_class(quant_target: QuantizationTarget,
29
- quantizer_type: Union[TrainingMethod, RoundingType],
27
+ quantizer_type: Union[Any, Any],
30
28
  quant_method: QuantizationMethod,
31
29
  quantizer_base_class: type) -> type:
32
30
  """
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from abc import ABC
16
16
  from typing import Dict, List
17
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
18
18
 
19
19
 
20
20
  class TrainableQuantizerCandidateConfig:
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Dict, Any, Union, List
16
16
 
17
- from model_compression_toolkit.core.common import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TF
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TF
19
19
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
20
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
21
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
@@ -17,7 +17,7 @@ import copy
17
17
  from typing import Any, Union
18
18
  from enum import Enum
19
19
 
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
21
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
22
22
  TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
23
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common import constants as C
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Union, List
16
16
 
17
- from model_compression_toolkit.core.common.logger import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
19
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
20
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
21
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
@@ -0,0 +1,27 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # TP Model constants
17
+ OPS_SET_LIST = 'ops_set_list'
18
+
19
+ # Version
20
+ LATEST = 'latest'
21
+
22
+
23
+ # Supported TP models names:
24
+ DEFAULT_TP_MODEL = 'default'
25
+ IMX500_TP_MODEL = 'imx500'
26
+ TFLITE_TP_MODEL = 'tflite'
27
+ QNNPACK_TP_MODEL = 'qnnpack'
@@ -13,16 +13,16 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.target_platform.fusing import Fusing
17
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import \
16
+ from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
18
18
  TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
19
19
 
20
- from model_compression_toolkit.core.common.target_platform.target_platform_model import \
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import \
21
21
  get_default_quantization_config_options, TargetPlatformModel
22
22
 
23
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import OpQuantizationConfig, \
23
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
24
24
  QuantizationConfigOptions, QuantizationMethod
25
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSet, OperatorSetConcat
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat
26
26
 
27
27
 
28
28
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.logger import Logger
16
+ from model_compression_toolkit.logger import Logger
17
17
 
18
18
  def get_current_tp_model():
19
19
  """
@@ -16,8 +16,8 @@
16
16
 
17
17
  from typing import Any
18
18
 
19
- from model_compression_toolkit.core.common.target_platform.operators import OperatorSetConcat
20
- from model_compression_toolkit.core.common.target_platform.target_platform_model_component import TargetPlatformModelComponent
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
21
21
 
22
22
 
23
23
  class Fusing(TargetPlatformModelComponent):
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
  from typing import Dict, Any
16
16
 
17
- from model_compression_toolkit.core.common.constants import OPS_SET_LIST
18
- from model_compression_toolkit.core.common.target_platform.target_platform_model_component import TargetPlatformModelComponent
19
- from model_compression_toolkit.core.common.target_platform.current_tp_model import _current_tp_model
20
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationConfigOptions
17
+ from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions
21
21
 
22
22
 
23
23
  class OperatorsSetBase(TargetPlatformModelComponent):
@@ -0,0 +1,20 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from enum import Enum
16
+
17
+
18
+ class QuantizationFormat(Enum):
19
+ FAKELY_QUANT = 0
20
+ INT8 = 1
@@ -16,16 +16,16 @@
16
16
  import pprint
17
17
  from typing import Any, Dict
18
18
 
19
- from model_compression_toolkit.core.common.target_platform.current_tp_model import _current_tp_model, \
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model, \
20
20
  get_current_tp_model
21
- from model_compression_toolkit.core.common.target_platform.fusing import Fusing
22
- from model_compression_toolkit.core.common.target_platform.target_platform_model_component import \
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import \
23
23
  TargetPlatformModelComponent
24
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import OpQuantizationConfig, \
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
25
25
  QuantizationConfigOptions
26
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSetBase
27
- from model_compression_toolkit.core.common.immutable import ImmutableClass
28
- from model_compression_toolkit.core.common.logger import Logger
26
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
27
+ from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
28
+ from model_compression_toolkit.logger import Logger
29
29
 
30
30
 
31
31
  def get_default_quantization_config_options() -> QuantizationConfigOptions:
@@ -223,3 +223,12 @@ class TargetPlatformModel(ImmutableClass):
223
223
 
224
224
  """
225
225
  pprint.pprint(self.get_info(), sort_dicts=False)
226
+
227
+ def set_quantization_format(self,
228
+ quantization_format: Any):
229
+ """
230
+ Set quantization format.
231
+ Args:
232
+ quantization_format: A quantization format (fake-quant, int8 etc.) from enum QuantizationFormat.
233
+ """
234
+ self.quantization_format = quantization_format
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import Any, Dict
16
16
 
17
- from model_compression_toolkit.core.common.target_platform.current_tp_model import _current_tp_model
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
18
18
 
19
19
 
20
20
  class TargetPlatformModelComponent:
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import get_current_tpc
17
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.target_platform_capabilities import TargetPlatformCapabilities
18
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.attribute_filter import \
16
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import get_current_tpc
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities import TargetPlatformCapabilities
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import \
19
19
  Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
20
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.layer_filter_params import \
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import \
21
21
  LayerFilterParams
22
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.operations_to_layers import \
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
23
23
  OperationsToLayers, OperationsSetToLayers
24
24
 
25
25
 
@@ -16,7 +16,7 @@
16
16
  import operator
17
17
  from typing import Any, Callable, Dict
18
18
 
19
- from model_compression_toolkit.core.common.logger import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
 
21
21
 
22
22
  class Filter:
@@ -13,10 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Any, Dict
17
-
18
- from model_compression_toolkit.core.common.graph.base_node import BaseNode
19
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
16
+ from typing import Any
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
20
18
 
21
19
 
22
20
  class LayerFilterParams:
@@ -87,34 +85,34 @@ class LayerFilterParams:
87
85
  params.extend([str(c) for c in self.conditions])
88
86
  params_str = ', '.join(params)
89
87
  return f'{self.layer.__name__}({params_str})'
90
-
91
- def match(self,
92
- node: BaseNode) -> bool:
93
- """
94
- Check if a node matches the layer, conditions and keyword-arguments of
95
- the LayerFilterParams.
96
-
97
- Args:
98
- node: Node to check if matches to the LayerFilterParams properties.
99
-
100
- Returns:
101
- Whether the node matches to the LayerFilterParams properties.
102
- """
103
- # Check the node has the same type as the layer in LayerFilterParams
104
- if self.layer != node.type:
105
- return False
106
-
107
- # Get attributes from node to filter
108
- layer_config = node.framework_attr
109
- if hasattr(node, "op_call_kwargs"):
110
- layer_config.update(node.op_call_kwargs)
111
-
112
- for attr, value in self.kwargs.items():
113
- if layer_config.get(attr) != value:
114
- return False
115
-
116
- for c in self.conditions:
117
- if not c.match(layer_config):
118
- return False
119
-
120
- return True
88
+ #
89
+ # def match(self,
90
+ # node: BaseNode) -> bool:
91
+ # """
92
+ # Check if a node matches the layer, conditions and keyword-arguments of
93
+ # the LayerFilterParams.
94
+ #
95
+ # Args:
96
+ # node: Node to check if matches to the LayerFilterParams properties.
97
+ #
98
+ # Returns:
99
+ # Whether the node matches to the LayerFilterParams properties.
100
+ # """
101
+ # # Check the node has the same type as the layer in LayerFilterParams
102
+ # if self.layer != node.type:
103
+ # return False
104
+ #
105
+ # # Get attributes from node to filter
106
+ # layer_config = node.framework_attr
107
+ # if hasattr(node, "op_call_kwargs"):
108
+ # layer_config.update(node.op_call_kwargs)
109
+ #
110
+ # for attr, value in self.kwargs.items():
111
+ # if layer_config.get(attr) != value:
112
+ # return False
113
+ #
114
+ # for c in self.conditions:
115
+ # if not c.match(layer_config):
116
+ # return False
117
+ #
118
+ # return True
@@ -15,10 +15,10 @@
15
15
 
16
16
  from typing import List, Any
17
17
 
18
- from model_compression_toolkit.core.common.logger import Logger
19
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import _current_tpc
20
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
21
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSet, OperatorSetConcat, \
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
22
22
  OperatorsSetBase
23
23
 
24
24
 
@@ -18,18 +18,17 @@ import itertools
18
18
  import pprint
19
19
  from typing import List, Any, Dict, Tuple
20
20
 
21
- from model_compression_toolkit.core.common.logger import Logger
22
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.operations_to_layers import \
21
+ from model_compression_toolkit.logger import Logger
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
23
23
  OperationsToLayers, OperationsSetToLayers
24
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
25
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
26
- from model_compression_toolkit.core.common.immutable import ImmutableClass
27
- from model_compression_toolkit.core.common.graph.base_node import BaseNode
28
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationConfigOptions, \
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
26
+ from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
27
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions, \
29
28
  OpQuantizationConfig
30
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSet, OperatorsSetBase
31
- from model_compression_toolkit.core.common.target_platform.target_platform_model import TargetPlatformModel
32
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import _current_tpc
29
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
30
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import TargetPlatformModel
31
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
33
32
 
34
33
 
35
34
  class TargetPlatformCapabilities(ImmutableClass):
@@ -163,26 +162,6 @@ class TargetPlatformCapabilities(ImmutableClass):
163
162
  """
164
163
  return self.tp_model.get_default_op_quantization_config()
165
164
 
166
- def get_qco_by_node(self,
167
- node: BaseNode) -> QuantizationConfigOptions:
168
- """
169
- Get the QuantizationConfigOptions of a node in a graph according
170
- to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformModel.
171
-
172
- Args:
173
- node: Node from graph to get its QuantizationConfigOptions.
174
-
175
- Returns:
176
- QuantizationConfigOptions of the node.
177
- """
178
- if node is None:
179
- Logger.error(f'Can not retrieve QC options for None node') # pragma: no cover
180
- for fl, qco in self.filterlayer2qco.items():
181
- if fl.match(node):
182
- return qco
183
- if node.type in self.layer2qco:
184
- return self.layer2qco.get(node.type)
185
- return self.tp_model.default_qco
186
165
 
187
166
  def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions],
188
167
  Dict[LayerFilterParams, QuantizationConfigOptions]]:
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import _current_tpc
16
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
17
17
 
18
18
 
19
19
  class TargetPlatformCapabilitiesComponent:
@@ -0,0 +1,25 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH
16
+
17
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
18
+
19
+ if FOUND_TF:
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_latest
21
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4.tpc_keras import generate_keras_tpc
22
+
23
+ if FOUND_TORCH:
24
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_latest
25
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4.tpc_pytorch import generate_pytorch_tpc
@@ -13,21 +13,23 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH, LATEST
16
+ from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
17
+ from model_compression_toolkit.target_platform_capabilities.constants import LATEST
18
+
17
19
 
18
20
  ###############################
19
21
  # Build Tensorflow TPC models
20
22
  ###############################
21
23
  keras_tpc_models_dict = None
22
24
  if FOUND_TF:
23
- from model_compression_toolkit.core.tpc_models.default_tpc.latest import get_keras_tpc_latest
24
- from model_compression_toolkit.core.tpc_models.default_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
25
- from model_compression_toolkit.core.tpc_models.default_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2
26
- from model_compression_toolkit.core.tpc_models.default_tpc.v3.tpc_keras import get_keras_tpc as get_keras_tpc_v3
27
- from model_compression_toolkit.core.tpc_models.default_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_v4
28
- from model_compression_toolkit.core.tpc_models.default_tpc.v3_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v3_lut
29
- from model_compression_toolkit.core.tpc_models.default_tpc.v4_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v4_lut
30
- from model_compression_toolkit.core.tpc_models.default_tpc.v5.tpc_keras import get_keras_tpc as get_keras_tpc_v5
25
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.latest import get_keras_tpc_latest
26
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
27
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2
28
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v3.tpc_keras import get_keras_tpc as get_keras_tpc_v3
29
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_v4
30
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v3_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v3_lut
31
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v4_lut
32
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v5.tpc_keras import get_keras_tpc as get_keras_tpc_v5
31
33
 
32
34
  # Keras: TPC versioning
33
35
  keras_tpc_models_dict = {'v1': get_keras_tpc_v1(),
@@ -44,20 +46,20 @@ if FOUND_TF:
44
46
  ###############################
45
47
  pytorch_tpc_models_dict = None
46
48
  if FOUND_TORCH:
47
- from model_compression_toolkit.core.tpc_models.default_tpc.latest import get_pytorch_tpc_latest
48
- from model_compression_toolkit.core.tpc_models.default_tpc.v1.tpc_pytorch import \
49
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.latest import get_pytorch_tpc_latest
50
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v1.tpc_pytorch import \
49
51
  get_pytorch_tpc as get_pytorch_tpc_v1
50
- from model_compression_toolkit.core.tpc_models.default_tpc.v2.tpc_pytorch import \
52
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v2.tpc_pytorch import \
51
53
  get_pytorch_tpc as get_pytorch_tpc_v2
52
- from model_compression_toolkit.core.tpc_models.default_tpc.v3.tpc_pytorch import \
54
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v3.tpc_pytorch import \
53
55
  get_pytorch_tpc as get_pytorch_tpc_v3
54
- from model_compression_toolkit.core.tpc_models.default_tpc.v4.tpc_pytorch import \
56
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4.tpc_pytorch import \
55
57
  get_pytorch_tpc as get_pytorch_tpc_v4
56
- from model_compression_toolkit.core.tpc_models.default_tpc.v3_lut.tpc_pytorch import \
58
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v3_lut.tpc_pytorch import \
57
59
  get_pytorch_tpc as get_pytorch_tpc_v3_lut
58
- from model_compression_toolkit.core.tpc_models.default_tpc.v4_lut.tpc_pytorch import \
60
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v4_lut.tpc_pytorch import \
59
61
  get_pytorch_tpc as get_pytorch_tpc_v4_lut
60
- from model_compression_toolkit.core.tpc_models.default_tpc.v5.tpc_pytorch import \
62
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v5.tpc_pytorch import \
61
63
  get_pytorch_tpc as get_pytorch_tpc_v5
62
64
 
63
65
  # Pytorch: TPC versioning
@@ -15,7 +15,10 @@
15
15
  from typing import List, Tuple
16
16
 
17
17
  import model_compression_toolkit as mct
18
- from model_compression_toolkit.core.common.target_platform import OpQuantizationConfig, TargetPlatformModel
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
19
+ TargetPlatformModel
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
21
+ QuantizationFormat
19
22
 
20
23
  tp = mct.target_platform
21
24
 
@@ -106,6 +109,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
106
109
  # be used for operations that will be attached to this set's label.
107
110
  # Otherwise, it will be a configure-less set (used in fusing):
108
111
 
112
+ # Set quantization format to fakely quant
113
+ generated_tpc.set_quantization_format(QuantizationFormat.FAKELY_QUANT)
114
+
109
115
  # May suit for operations like: Dropout, Reshape, etc.
110
116
  tp.OperatorsSet("NoQuantization",
111
117
  tp.get_default_quantization_config_options().clone_and_edit(
@@ -23,9 +23,9 @@ else:
23
23
  from keras.layers import Conv2D, DepthwiseConv2D, Reshape, ZeroPadding2D, \
24
24
  Dropout, MaxPooling2D, Activation, ReLU, Flatten, Cropping2D
25
25
 
26
- from model_compression_toolkit.core.tpc_models.default_tpc.v1.tp_model import get_tp_model
26
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v1.tp_model import get_tp_model
27
27
  import model_compression_toolkit as mct
28
- from model_compression_toolkit.core.tpc_models.default_tpc.v1 import __version__ as TPC_VERSION
28
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.default_tpc.v1 import __version__ as TPC_VERSION
29
29
 
30
30
  tp = mct.target_platform
31
31