mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -12,12 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
- from typing import Union
15
+ from abc import abstractmethod
16
+ from enum import Enum
17
+ from typing import Union, List, Any
17
18
  from inspect import signature
18
19
 
19
20
  from model_compression_toolkit.core import common
20
- from model_compression_toolkit.core.common import Logger
21
+ from model_compression_toolkit.logger import Logger
21
22
 
22
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer, \
23
24
  QuantizationTarget
@@ -27,6 +28,19 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
27
28
  QUANTIZATION_TARGET
28
29
 
29
30
 
31
+ VAR = 'var'
32
+ GROUP = 'group'
33
+
34
+ class VariableGroup(Enum):
35
+ """
36
+ An enum for choosing trainable variable group
37
+ 0. WEIGHTS
38
+ 1. QPARAMS
39
+ """
40
+ WEIGHTS = 0
41
+ QPARAMS = 1
42
+
43
+
30
44
  class BaseTrainableQuantizer(BaseInferableQuantizer):
31
45
  def __init__(self,
32
46
  quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig]):
@@ -41,9 +55,9 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
41
55
  for i, (k, v) in enumerate(self.get_sig().parameters.items()):
42
56
  if i == 0:
43
57
  if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
44
- common.Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
58
+ Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
45
59
  elif v.default is v.empty:
46
- common.Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover
60
+ Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover
47
61
 
48
62
  super(BaseTrainableQuantizer, self).__init__()
49
63
  self.quantization_config = quantization_config
@@ -59,17 +73,19 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
59
73
  if static_quantization_target == QuantizationTarget.Weights:
60
74
  self.validate_weights()
61
75
  if self.quantization_config.weights_quantization_method not in static_quantization_method:
62
- common.Logger.error(
76
+ Logger.error(
63
77
  f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.weights_quantization_method}')
64
78
  elif static_quantization_target == QuantizationTarget.Activation:
65
79
  self.validate_activation()
66
80
  if self.quantization_config.activation_quantization_method not in static_quantization_method:
67
- common.Logger.error(
81
+ Logger.error(
68
82
  f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.activation_quantization_method}')
69
83
  else:
70
- common.Logger.error(
84
+ Logger.error(
71
85
  f'Unknown Quantization Part:{static_quantization_target}') # pragma: no cover
72
86
 
87
+ self.quantizer_parameters = {}
88
+
73
89
  @classmethod
74
90
  def get_sig(cls):
75
91
  return signature(cls)
@@ -129,7 +145,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
129
145
 
130
146
  """
131
147
  if self.activation_quantization() or not self.weights_quantization():
132
- common.Logger.error(f'Expect weight quantization got activation')
148
+ Logger.error(f'Expect weight quantization got activation')
133
149
 
134
150
  def validate_activation(self) -> None:
135
151
  """
@@ -137,7 +153,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
137
153
 
138
154
  """
139
155
  if not self.activation_quantization() or self.weights_quantization():
140
- common.Logger.error(f'Expect activation quantization got weight')
156
+ Logger.error(f'Expect activation quantization got weight')
141
157
 
142
158
  def convert2inferable(self) -> BaseInferableQuantizer:
143
159
  """
@@ -147,3 +163,38 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
147
163
  BaseInferableQuantizer object.
148
164
  """
149
165
  raise NotImplemented # pragma: no cover
166
+
167
+ def add_quantizer_variable(self, name: str, variable: Any, group: VariableGroup = VariableGroup.WEIGHTS):
168
+ """
169
+ Add a quantizer variable to quantizer_parameters dictionary
170
+ """
171
+ self.quantizer_parameters.update({name: {VAR: variable, GROUP: group}})
172
+
173
+ def get_quantizer_variable(self, name: str) -> Any:
174
+ """
175
+ Get a quantizer variable by name
176
+
177
+ Args:
178
+ name: variable name
179
+
180
+ Returns:
181
+ trainable variable
182
+ """
183
+ if name in self.quantizer_parameters:
184
+ return self.quantizer_parameters[name][VAR]
185
+ else:
186
+ Logger.error(f'Variable {name} is not exist in quantizers parameters!') # pragma: no cover
187
+
188
+
189
+ @abstractmethod
190
+ def get_trainable_variables(self, group: VariableGroup) -> List[Any]:
191
+ """
192
+ Get trainable parameters with specific group from quantizer
193
+
194
+ Args:
195
+ group: Enum of variable group
196
+
197
+ Returns:
198
+ List of trainable variables
199
+ """
200
+ raise NotImplemented # pragma: no cover
@@ -13,7 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from typing import List
16
- from model_compression_toolkit.core.common import BaseNode, Logger
16
+ from model_compression_toolkit.core.common import BaseNode
17
+ from model_compression_toolkit.logger import Logger
17
18
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
18
19
  TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig, TrainableQuantizerCandidateConfig
19
20
 
@@ -12,11 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Union
15
+ from typing import Union, Any
16
16
 
17
- from model_compression_toolkit import TrainingMethod, RoundingType
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
19
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
21
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
22
21
  import QUANTIZATION_TARGET, QUANTIZATION_METHOD, QUANTIZER_TYPE
@@ -25,7 +24,7 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
25
24
 
26
25
 
27
26
  def get_trainable_quantizer_class(quant_target: QuantizationTarget,
28
- quantizer_type: Union[TrainingMethod, RoundingType],
27
+ quantizer_type: Union[Any, Any],
29
28
  quant_method: QuantizationMethod,
30
29
  quantizer_base_class: type) -> type:
31
30
  """
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from abc import ABC
16
16
  from typing import Dict, List
17
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
18
18
 
19
19
 
20
20
  class TrainableQuantizerCandidateConfig:
@@ -12,12 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Dict, Any, Union
15
+ from typing import Dict, Any, Union, List
16
16
 
17
- from model_compression_toolkit.core.common import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TF
19
-
20
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TF
19
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
21
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
22
22
  TrainableQuantizerActivationConfig
23
23
 
@@ -25,7 +25,7 @@ if FOUND_TF:
25
25
  QUANTIZATION_CONFIG = 'quantization_config'
26
26
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.config_serialization import config_serialization, \
27
27
  config_deserialization
28
-
28
+ import tensorflow as tf
29
29
 
30
30
  class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
31
31
  def __init__(self,
@@ -61,6 +61,24 @@ if FOUND_TF:
61
61
  # Note that a quantizer only receive quantization config and the rest of define hardcoded inside the speficie quantizer.
62
62
  return cls(quantization_config=quantization_config)
63
63
 
64
+ def get_trainable_variables(self, group: VariableGroup) -> List[tf.Tensor]:
65
+ """
66
+ Get trainable parameters with specific group from quantizer
67
+
68
+ Args:
69
+ group: Enum of variable group
70
+
71
+ Returns:
72
+ List of trainable variables
73
+ """
74
+ quantizer_trainable = []
75
+ for name, parameter_dict in self.quantizer_parameters.items():
76
+ quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
77
+ if quantizer_parameter.trainable and parameter_group == group:
78
+ quantizer_trainable.append(quantizer_parameter)
79
+ return quantizer_trainable
80
+
81
+
64
82
  else:
65
83
  class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
66
84
  def __init__(self,
@@ -17,7 +17,7 @@ import copy
17
17
  from typing import Any, Union
18
18
  from enum import Enum
19
19
 
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
21
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
22
22
  TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
23
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common import constants as C
@@ -12,17 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Union
15
+ from typing import Union, List
16
16
 
17
- from model_compression_toolkit.core.common.logger import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
-
20
- from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
20
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
21
21
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
22
22
  TrainableQuantizerActivationConfig
23
23
 
24
+
24
25
  if FOUND_TORCH:
25
26
 
27
+ import torch
28
+
26
29
  class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
27
30
  def __init__(self,
28
31
  quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
@@ -35,6 +38,24 @@ if FOUND_TORCH:
35
38
  """
36
39
  super().__init__(quantization_config)
37
40
 
41
+
42
+ def get_trainable_variables(self, group: VariableGroup) -> List[torch.Tensor]:
43
+ """
44
+ Get trainable parameters with specific group from quantizer
45
+
46
+ Args:
47
+ group: Enum of variable group
48
+
49
+ Returns:
50
+ List of trainable variables
51
+ """
52
+ quantizer_trainable = []
53
+ for name, parameter_dict in self.quantizer_parameters.items():
54
+ quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
55
+ if quantizer_parameter.requires_grad and parameter_group == group:
56
+ quantizer_trainable.append(quantizer_parameter)
57
+ return quantizer_trainable
58
+
38
59
  else:
39
60
  class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
40
61
  def __init__(self,
@@ -0,0 +1,27 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # TP Model constants
17
+ OPS_SET_LIST = 'ops_set_list'
18
+
19
+ # Version
20
+ LATEST = 'latest'
21
+
22
+
23
+ # Supported TP models names:
24
+ DEFAULT_TP_MODEL = 'default'
25
+ IMX500_TP_MODEL = 'imx500'
26
+ TFLITE_TP_MODEL = 'tflite'
27
+ QNNPACK_TP_MODEL = 'qnnpack'
@@ -13,16 +13,16 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.target_platform.fusing import Fusing
17
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import \
16
+ from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
18
18
  TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
19
19
 
20
- from model_compression_toolkit.core.common.target_platform.target_platform_model import \
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import \
21
21
  get_default_quantization_config_options, TargetPlatformModel
22
22
 
23
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import OpQuantizationConfig, \
23
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
24
24
  QuantizationConfigOptions, QuantizationMethod
25
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSet, OperatorSetConcat
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat
26
26
 
27
27
 
28
28
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.logger import Logger
16
+ from model_compression_toolkit.logger import Logger
17
17
 
18
18
  def get_current_tp_model():
19
19
  """
@@ -16,8 +16,8 @@
16
16
 
17
17
  from typing import Any
18
18
 
19
- from model_compression_toolkit.core.common.target_platform.operators import OperatorSetConcat
20
- from model_compression_toolkit.core.common.target_platform.target_platform_model_component import TargetPlatformModelComponent
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
21
21
 
22
22
 
23
23
  class Fusing(TargetPlatformModelComponent):
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
  from typing import Dict, Any
16
16
 
17
- from model_compression_toolkit.core.common.constants import OPS_SET_LIST
18
- from model_compression_toolkit.core.common.target_platform.target_platform_model_component import TargetPlatformModelComponent
19
- from model_compression_toolkit.core.common.target_platform.current_tp_model import _current_tp_model
20
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationConfigOptions
17
+ from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions
21
21
 
22
22
 
23
23
  class OperatorsSetBase(TargetPlatformModelComponent):
@@ -0,0 +1,20 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from enum import Enum
16
+
17
+
18
+ class QuantizationFormat(Enum):
19
+ FAKELY_QUANT = 0
20
+ INT8 = 1
@@ -16,16 +16,16 @@
16
16
  import pprint
17
17
  from typing import Any, Dict
18
18
 
19
- from model_compression_toolkit.core.common.target_platform.current_tp_model import _current_tp_model, \
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model, \
20
20
  get_current_tp_model
21
- from model_compression_toolkit.core.common.target_platform.fusing import Fusing
22
- from model_compression_toolkit.core.common.target_platform.target_platform_model_component import \
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import \
23
23
  TargetPlatformModelComponent
24
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import OpQuantizationConfig, \
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
25
25
  QuantizationConfigOptions
26
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSetBase
27
- from model_compression_toolkit.core.common.immutable import ImmutableClass
28
- from model_compression_toolkit.core.common.logger import Logger
26
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
27
+ from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
28
+ from model_compression_toolkit.logger import Logger
29
29
 
30
30
 
31
31
  def get_default_quantization_config_options() -> QuantizationConfigOptions:
@@ -223,3 +223,12 @@ class TargetPlatformModel(ImmutableClass):
223
223
 
224
224
  """
225
225
  pprint.pprint(self.get_info(), sort_dicts=False)
226
+
227
+ def set_quantization_format(self,
228
+ quantization_format: Any):
229
+ """
230
+ Set quantization format.
231
+ Args:
232
+ quantization_format: A quantization format (fake-quant, int8 etc.) from enum QuantizationFormat.
233
+ """
234
+ self.quantization_format = quantization_format
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import Any, Dict
16
16
 
17
- from model_compression_toolkit.core.common.target_platform.current_tp_model import _current_tp_model
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
18
18
 
19
19
 
20
20
  class TargetPlatformModelComponent:
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import get_current_tpc
17
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.target_platform_capabilities import TargetPlatformCapabilities
18
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.attribute_filter import \
16
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import get_current_tpc
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities import TargetPlatformCapabilities
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import \
19
19
  Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
20
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.layer_filter_params import \
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import \
21
21
  LayerFilterParams
22
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.operations_to_layers import \
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
23
23
  OperationsToLayers, OperationsSetToLayers
24
24
 
25
25
 
@@ -16,7 +16,7 @@
16
16
  import operator
17
17
  from typing import Any, Callable, Dict
18
18
 
19
- from model_compression_toolkit.core.common.logger import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
 
21
21
 
22
22
  class Filter:
@@ -13,10 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Any, Dict
17
-
18
- from model_compression_toolkit.core.common.graph.base_node import BaseNode
19
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
16
+ from typing import Any
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
20
18
 
21
19
 
22
20
  class LayerFilterParams:
@@ -87,34 +85,34 @@ class LayerFilterParams:
87
85
  params.extend([str(c) for c in self.conditions])
88
86
  params_str = ', '.join(params)
89
87
  return f'{self.layer.__name__}({params_str})'
90
-
91
- def match(self,
92
- node: BaseNode) -> bool:
93
- """
94
- Check if a node matches the layer, conditions and keyword-arguments of
95
- the LayerFilterParams.
96
-
97
- Args:
98
- node: Node to check if matches to the LayerFilterParams properties.
99
-
100
- Returns:
101
- Whether the node matches to the LayerFilterParams properties.
102
- """
103
- # Check the node has the same type as the layer in LayerFilterParams
104
- if self.layer != node.type:
105
- return False
106
-
107
- # Get attributes from node to filter
108
- layer_config = node.framework_attr
109
- if hasattr(node, "op_call_kwargs"):
110
- layer_config.update(node.op_call_kwargs)
111
-
112
- for attr, value in self.kwargs.items():
113
- if layer_config.get(attr) != value:
114
- return False
115
-
116
- for c in self.conditions:
117
- if not c.match(layer_config):
118
- return False
119
-
120
- return True
88
+ #
89
+ # def match(self,
90
+ # node: BaseNode) -> bool:
91
+ # """
92
+ # Check if a node matches the layer, conditions and keyword-arguments of
93
+ # the LayerFilterParams.
94
+ #
95
+ # Args:
96
+ # node: Node to check if matches to the LayerFilterParams properties.
97
+ #
98
+ # Returns:
99
+ # Whether the node matches to the LayerFilterParams properties.
100
+ # """
101
+ # # Check the node has the same type as the layer in LayerFilterParams
102
+ # if self.layer != node.type:
103
+ # return False
104
+ #
105
+ # # Get attributes from node to filter
106
+ # layer_config = node.framework_attr
107
+ # if hasattr(node, "op_call_kwargs"):
108
+ # layer_config.update(node.op_call_kwargs)
109
+ #
110
+ # for attr, value in self.kwargs.items():
111
+ # if layer_config.get(attr) != value:
112
+ # return False
113
+ #
114
+ # for c in self.conditions:
115
+ # if not c.match(layer_config):
116
+ # return False
117
+ #
118
+ # return True
@@ -15,10 +15,10 @@
15
15
 
16
16
  from typing import List, Any
17
17
 
18
- from model_compression_toolkit.core.common.logger import Logger
19
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import _current_tpc
20
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
21
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSet, OperatorSetConcat, \
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
22
22
  OperatorsSetBase
23
23
 
24
24
 
@@ -18,18 +18,17 @@ import itertools
18
18
  import pprint
19
19
  from typing import List, Any, Dict, Tuple
20
20
 
21
- from model_compression_toolkit.core.common.logger import Logger
22
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.operations_to_layers import \
21
+ from model_compression_toolkit.logger import Logger
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
23
23
  OperationsToLayers, OperationsSetToLayers
24
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
25
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
26
- from model_compression_toolkit.core.common.immutable import ImmutableClass
27
- from model_compression_toolkit.core.common.graph.base_node import BaseNode
28
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationConfigOptions, \
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
26
+ from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
27
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions, \
29
28
  OpQuantizationConfig
30
- from model_compression_toolkit.core.common.target_platform.operators import OperatorsSet, OperatorsSetBase
31
- from model_compression_toolkit.core.common.target_platform.target_platform_model import TargetPlatformModel
32
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import _current_tpc
29
+ from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
30
+ from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import TargetPlatformModel
31
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
33
32
 
34
33
 
35
34
  class TargetPlatformCapabilities(ImmutableClass):
@@ -163,26 +162,6 @@ class TargetPlatformCapabilities(ImmutableClass):
163
162
  """
164
163
  return self.tp_model.get_default_op_quantization_config()
165
164
 
166
- def get_qco_by_node(self,
167
- node: BaseNode) -> QuantizationConfigOptions:
168
- """
169
- Get the QuantizationConfigOptions of a node in a graph according
170
- to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformModel.
171
-
172
- Args:
173
- node: Node from graph to get its QuantizationConfigOptions.
174
-
175
- Returns:
176
- QuantizationConfigOptions of the node.
177
- """
178
- if node is None:
179
- Logger.error(f'Can not retrieve QC options for None node') # pragma: no cover
180
- for fl, qco in self.filterlayer2qco.items():
181
- if fl.match(node):
182
- return qco
183
- if node.type in self.layer2qco:
184
- return self.layer2qco.get(node.type)
185
- return self.tp_model.default_qco
186
165
 
187
166
  def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions],
188
167
  Dict[LayerFilterParams, QuantizationConfigOptions]]:
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.current_tpc import _current_tpc
16
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
17
17
 
18
18
 
19
19
  class TargetPlatformCapabilitiesComponent: