mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.38.4)
2
+ Generator: bdist_wheel (0.40.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -13,48 +13,19 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
17
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2
18
- from model_compression_toolkit.gptq.common.gptq_quantizer_config import GPTQQuantizerConfig, SoftQuantizerConfig
19
- from model_compression_toolkit.core.common.quantization import quantization_config
20
- from model_compression_toolkit.core.common.mixed_precision import mixed_precision_quantization_config
21
- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
22
- QuantizationErrorMethod, DEFAULTCONFIG
23
- from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
24
- from model_compression_toolkit.core.common import target_platform
25
- from model_compression_toolkit.core.tpc_models.get_target_platform_capabilities import get_target_platform_capabilities
26
- from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
27
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
28
- MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2
29
- from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
30
- from model_compression_toolkit.core.common.logger import set_log_folder
31
- from model_compression_toolkit.core.common.data_loader import FolderImageLoader
32
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
33
- from model_compression_toolkit.core.common.defaultdict import DefaultDict
34
- from model_compression_toolkit.core.common import network_editors as network_editor
35
16
 
36
- from model_compression_toolkit.core.keras.quantization_facade import keras_post_training_quantization, \
37
- keras_post_training_quantization_mixed_precision
38
- from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization_experimental
39
- from model_compression_toolkit.gptq.keras.quantization_facade import \
40
- keras_gradient_post_training_quantization_experimental
41
- from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
42
- from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, \
43
- keras_quantization_aware_training_finalize
44
- from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init, \
45
- pytorch_quantization_aware_training_finalize
46
- from model_compression_toolkit.core.pytorch.quantization_facade import pytorch_post_training_quantization, \
47
- pytorch_post_training_quantization_mixed_precision
48
- from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
49
- from model_compression_toolkit.gptq.pytorch.quantization_facade import \
50
- pytorch_gradient_post_training_quantization_experimental
51
- from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
17
+ from model_compression_toolkit.target_platform_capabilities import target_platform
18
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import get_target_platform_capabilities
19
+ from model_compression_toolkit import core
20
+ from model_compression_toolkit.logger import set_log_folder
21
+ from model_compression_toolkit.legacy.keras_quantization_facade import keras_post_training_quantization, keras_post_training_quantization_mixed_precision
22
+ from model_compression_toolkit.legacy.pytorch_quantization_facade import pytorch_post_training_quantization, pytorch_post_training_quantization_mixed_precision
23
+ from model_compression_toolkit import quantizers_infrastructure
24
+ from model_compression_toolkit import ptq
25
+ from model_compression_toolkit import qat
26
+ from model_compression_toolkit import exporter
27
+ from model_compression_toolkit import gptq
28
+ from model_compression_toolkit.gptq import GradientPTQConfig
52
29
 
53
- from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data, keras_kpi_data_experimental
54
- from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental
55
-
56
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
57
-
58
- from model_compression_toolkit.exporter.model_exporter import tflite_export_model, TFLiteExportMode, keras_export_model, KerasExportMode, pytorch_export_model, PyTorchExportMode
59
30
 
60
31
  __version__ = "1.8.0"
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
15
16
  import importlib
16
17
 
17
18
  # Supported frameworks in MCT:
@@ -47,9 +48,6 @@ LAST_AXIS = -1
47
48
  DATA_TYPE = 'dtype'
48
49
  FLOAT_32 = 'float32'
49
50
 
50
- # Version
51
- LATEST = 'latest'
52
-
53
51
  # Number of Tensorboard cosine-similarity plots to add:
54
52
  NUM_SAMPLES_DISTANCE_TENSORBOARD = 20
55
53
 
@@ -119,12 +117,10 @@ WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
119
117
  DUMMY_NODE = 'dummy_node'
120
118
  DUMMY_TENSOR = 'dummy_tensor'
121
119
 
122
- # TP Model constants
123
- OPS_SET_LIST = 'ops_set_list'
124
120
 
125
121
  # TF Input node base name
126
122
  INPUT_BASE_NAME = 'base_input'
127
123
 
128
124
  # Jacobian-weights constants
129
125
  MIN_JACOBIANS_ITER = 10
130
- JACOBIANS_COMP_TOLERANCE = 1e-3
126
+ JACOBIANS_COMP_TOLERANCE = 1e-3
@@ -12,3 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
16
+ from model_compression_toolkit.core.common.data_loader import FolderImageLoader
17
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
18
+ from model_compression_toolkit.core.common.defaultdict import DefaultDict
19
+ from model_compression_toolkit.core.common import network_editors as network_editor
20
+ from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
21
+ from model_compression_toolkit.core.common.quantization import quantization_config
22
+ from model_compression_toolkit.core.common.mixed_precision import mixed_precision_quantization_config
23
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG
24
+ from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
25
+ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
26
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2
27
+ from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data, keras_kpi_data_experimental
28
+ from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental
@@ -17,14 +17,15 @@
17
17
  from typing import Callable
18
18
 
19
19
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
- from model_compression_toolkit.core.common import FrameworkInfo, Logger
21
- from model_compression_toolkit.core.common.constants import NUM_SAMPLES_DISTANCE_TENSORBOARD
20
+ from model_compression_toolkit.core.common import FrameworkInfo
21
+ from model_compression_toolkit.constants import NUM_SAMPLES_DISTANCE_TENSORBOARD
22
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
23
23
 
24
24
  from model_compression_toolkit.core.common.similarity_analyzer import compute_cs
25
25
  from model_compression_toolkit.core.common.visualization.nn_visualizer import NNVisualizer
26
26
 
27
27
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
28
+ from model_compression_toolkit.logger import Logger
28
29
 
29
30
 
30
31
  def analyzer_model_quantization(representative_data_gen: Callable,
@@ -17,7 +17,6 @@ from model_compression_toolkit.core.common.base_substitutions import BaseSubstit
17
17
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
18
18
  from model_compression_toolkit.core.common.graph.base_graph import Graph
19
19
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
20
- from model_compression_toolkit.core.common.logger import Logger
21
20
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, DEFAULTCONFIG
22
21
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two
23
22
  from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector, NoStatsCollector
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import numpy as np
17
- from model_compression_toolkit.core.common.logger import Logger
17
+ from model_compression_toolkit.logger import Logger
18
18
 
19
19
 
20
20
  class BaseCollector(object):
@@ -17,7 +17,7 @@
17
17
  import numpy as np
18
18
 
19
19
  from model_compression_toolkit.core.common.collectors.base_collector import BaseCollector
20
- from model_compression_toolkit.core.common.constants import LAST_AXIS
20
+ from model_compression_toolkit.constants import LAST_AXIS
21
21
 
22
22
 
23
23
  class MeanCollector(BaseCollector):
@@ -16,7 +16,7 @@
16
16
  import numpy as np
17
17
 
18
18
  from model_compression_toolkit.core.common.collectors.base_collector import BaseCollector
19
- from model_compression_toolkit.core.common.constants import LAST_AXIS
19
+ from model_compression_toolkit.constants import LAST_AXIS
20
20
 
21
21
 
22
22
  class MinMaxPerChannelCollector(BaseCollector):
@@ -17,7 +17,7 @@ from typing import Callable, Any, List, Tuple, Dict
17
17
 
18
18
  import numpy as np
19
19
 
20
- from model_compression_toolkit import MixedPrecisionQuantizationConfigV2
20
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2
21
21
  from model_compression_toolkit.core import common
22
22
  from model_compression_toolkit.core.common import BaseNode
23
23
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -288,13 +288,6 @@ class FrameworkImplementation(ABC):
288
288
  f'framework\'s get_substitutions_after_second_moment_correction '
289
289
  f'method.') # pragma: no cover
290
290
 
291
- @abstractmethod
292
- def get_gptq_trainer_obj(self):
293
- """
294
- Returns: GPTQTrainer object
295
- """
296
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
297
- f'framework\'s get_gptq_trainer method.') # pragma: no cover
298
291
 
299
292
  @abstractmethod
300
293
  def get_sensitivity_evaluator(self,
@@ -22,7 +22,7 @@ from typing import Dict, Any, List
22
22
 
23
23
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
24
24
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
25
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationMethod
26
26
 
27
27
 
28
28
  class ChannelAxis(Enum):
@@ -16,8 +16,8 @@ import copy
16
16
  from typing import Any, List
17
17
  from model_compression_toolkit.core.common.graph.base_graph import Graph
18
18
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
19
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
20
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
21
21
 
22
22
 
23
23
  def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx: int = 0) -> List[List[Any]]:
@@ -33,7 +33,7 @@ def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx
33
33
  valid_fusing_patterns = []
34
34
  for i,fusing_pattern in enumerate(fusing_patterns):
35
35
  if idx < len(fusing_pattern):
36
- if (type(fusing_pattern[idx]) == LayerFilterParams and fusing_pattern[idx].match(node)) or fusing_pattern[idx] == node.type:
36
+ if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or fusing_pattern[idx] == node.type:
37
37
  valid_fusing_patterns.append(fusing_pattern)
38
38
 
39
39
  # Return only valid patterns for this node
@@ -57,7 +57,7 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
57
57
  continue
58
58
  counter = 0
59
59
  for i,layer in enumerate(fusing_pattern):
60
- if (type(layer) == LayerFilterParams and layer.match(nodes[i])) or layer == nodes[i].type:
60
+ if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or layer == nodes[i].type:
61
61
  counter += 1
62
62
  if counter == fusion_depth:
63
63
  return True
@@ -30,8 +30,8 @@ from model_compression_toolkit.core.common.graph.base_node import BaseNode
30
30
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
31
31
  from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics
32
32
  from model_compression_toolkit.core.common.user_info import UserInformation
33
- from model_compression_toolkit.core.common.logger import Logger
34
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
33
+ from model_compression_toolkit.logger import Logger
34
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
35
35
 
36
36
  OutTensor = namedtuple('OutTensor', 'node node_out_index')
37
37
 
@@ -18,8 +18,11 @@ from typing import Dict, Any, Tuple, List
18
18
 
19
19
  import numpy as np
20
20
 
21
- from model_compression_toolkit.core.common.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
21
+ from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
22
22
  ACTIVATION_NBITS_ATTRIBUTE
23
+ from model_compression_toolkit.logger import Logger
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
25
+ TargetPlatformCapabilities, LayerFilterParams
23
26
 
24
27
 
25
28
  class BaseNode:
@@ -429,3 +432,56 @@ class BaseNode:
429
432
 
430
433
  return len(self.candidates_quantization_cfg) > 0 and \
431
434
  any([c.activation_quantization_cfg.enable_activation_quantization for c in self.candidates_quantization_cfg])
435
+
436
+ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions:
437
+ """
438
+ Get the QuantizationConfigOptions of the node according
439
+ to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformModel.
440
+
441
+ Args:
442
+ tpc: TPC to extract the QuantizationConfigOptions for the node
443
+
444
+ Returns:
445
+ QuantizationConfigOptions of the node.
446
+ """
447
+
448
+ if tpc is None:
449
+ Logger.error(f'Can not retrieve QC options for None TPC') # pragma: no cover
450
+
451
+ for fl, qco in tpc.filterlayer2qco.items():
452
+ if self.is_match_filter_params(fl):
453
+ return qco
454
+ if self.type in tpc.layer2qco:
455
+ return tpc.layer2qco.get(self.type)
456
+ return tpc.tp_model.default_qco
457
+
458
+
459
+ def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool:
460
+ """
461
+ Check if the node matches a LayerFilterParams according to its
462
+ layer, conditions and keyword-arguments.
463
+
464
+ Args:
465
+ layer_filter_params: LayerFilterParams to check if the node matches its properties.
466
+
467
+ Returns:
468
+ Whether the node matches to the LayerFilterParams properties.
469
+ """
470
+ # Check the node has the same type as the layer in LayerFilterParams
471
+ if layer_filter_params.layer != self.type:
472
+ return False
473
+
474
+ # Get attributes from node to filter
475
+ layer_config = self.framework_attr
476
+ if hasattr(self, "op_call_kwargs"):
477
+ layer_config.update(self.op_call_kwargs)
478
+
479
+ for attr, value in layer_filter_params.kwargs.items():
480
+ if layer_config.get(attr) != value:
481
+ return False
482
+
483
+ for c in layer_filter_params.conditions:
484
+ if not c.match(layer_config):
485
+ return False
486
+
487
+ return True
@@ -16,7 +16,7 @@
16
16
  from typing import Any, List, Tuple
17
17
  import networkx as nx
18
18
 
19
- from model_compression_toolkit.core.common.logger import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
 
21
21
 
22
22
  class DirectedBipartiteGraph(nx.DiGraph):
@@ -16,7 +16,7 @@ import copy
16
16
  from typing import List, Tuple, Dict
17
17
 
18
18
  from model_compression_toolkit.core.common import BaseNode
19
- from model_compression_toolkit.core.common.constants import DUMMY_TENSOR, DUMMY_NODE
19
+ from model_compression_toolkit.constants import DUMMY_TENSOR, DUMMY_NODE
20
20
  from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
21
21
  from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements
22
22
  from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import ActivationMemoryTensor, MemoryGraph
@@ -15,8 +15,8 @@
15
15
 
16
16
  from typing import Dict, Any, Tuple
17
17
 
18
- from model_compression_toolkit import FrameworkInfo
19
- from model_compression_toolkit.core.common.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
18
+ from model_compression_toolkit.core import FrameworkInfo
19
+ from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
20
20
  VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
21
21
 
22
22
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.core.common.constants import BITS_TO_BYTES
15
+ from model_compression_toolkit.constants import BITS_TO_BYTES
16
16
 
17
17
 
18
18
  def compute_quantize_tensor_memory_bytes(tensor_size: float, n_bits: int) -> float:
@@ -18,11 +18,11 @@ import copy
18
18
  from typing import Any, List
19
19
 
20
20
  from model_compression_toolkit.core.common import Graph, BaseNode
21
- from model_compression_toolkit.core.common.logger import Logger
21
+ from model_compression_toolkit.logger import Logger
22
22
 
23
23
 
24
24
  def set_bit_widths(mixed_precision_enable: bool,
25
- graph_to_set_bit_widths: Graph,
25
+ graph: Graph,
26
26
  bit_widths_config: List[int] = None) -> Graph:
27
27
  """
28
28
  Set bit widths configuration to nodes in a graph. For each node, use the desired index
@@ -30,13 +30,11 @@ def set_bit_widths(mixed_precision_enable: bool,
30
30
 
31
31
  Args:
32
32
  mixed_precision_enable: Is mixed precision enabled.
33
- graph_to_set_bit_widths: A prepared for quantization graph to set its bit widths.
33
+ graph: A prepared for quantization graph to set its bit widths.
34
34
  bit_widths_config: MP configuration (a list of indices: one for each node's candidate
35
35
  quantization configuration).
36
36
 
37
37
  """
38
- graph = copy.deepcopy(graph_to_set_bit_widths)
39
-
40
38
  if mixed_precision_enable:
41
39
  assert all([len(n.candidates_quantization_cfg) > 0 for n in graph.get_configurable_sorted_nodes()]), \
42
40
  "All configurable nodes in graph should have at least one candidate configuration in mixed precision mode"
@@ -15,14 +15,13 @@
15
15
  from typing import Callable, Any
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit import FrameworkInfo, KPI, CoreConfig
18
+ from model_compression_toolkit.core import FrameworkInfo, KPI, CoreConfig
19
19
  from model_compression_toolkit.core.common import Graph
20
- from model_compression_toolkit.core.common.constants import FLOAT_BITWIDTH
20
+ from model_compression_toolkit.constants import FLOAT_BITWIDTH
21
21
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
22
22
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
23
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
23
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
24
24
  from model_compression_toolkit.core.runner import read_model_to_graph, get_finalized_graph
25
- from model_compression_toolkit.core.common.logger import Logger
26
25
 
27
26
 
28
27
  def compute_kpi_data(in_model: Any,
@@ -18,14 +18,14 @@ from typing import List
18
18
 
19
19
  import numpy as np
20
20
 
21
- from model_compression_toolkit import FrameworkInfo
21
+ from model_compression_toolkit.core import FrameworkInfo
22
22
  from model_compression_toolkit.core.common import Graph, BaseNode
23
- from model_compression_toolkit.core.common.constants import BITS_TO_BYTES, FLOAT_BITWIDTH
23
+ from model_compression_toolkit.constants import BITS_TO_BYTES, FLOAT_BITWIDTH
24
24
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
25
25
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
26
26
  from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
27
27
  VirtualSplitWeightsNode, VirtualSplitActivationNode
28
- from model_compression_toolkit.core.common.logger import Logger
28
+ from model_compression_toolkit.logger import Logger
29
29
 
30
30
 
31
31
  def weights_size_kpi(mp_cfg: List[int],
@@ -16,7 +16,7 @@
16
16
  from enum import Enum
17
17
  from typing import List, Callable, Tuple
18
18
 
19
- from model_compression_toolkit.core.common import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
21
21
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, DEFAULTCONFIG
22
22
  from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
@@ -18,8 +18,8 @@ from enum import Enum
18
18
  import numpy as np
19
19
  from typing import List, Callable, Dict
20
20
 
21
- from model_compression_toolkit import MixedPrecisionQuantizationConfigV2
22
- from model_compression_toolkit.core.common import Graph, Logger
21
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2
22
+ from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
24
24
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_functions_mapping import kpi_functions_mapping
25
25
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -30,6 +30,7 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
30
30
  from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
31
31
  greedy_solution_refinement_procedure
32
32
  from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
33
+ from model_compression_toolkit.logger import Logger
33
34
 
34
35
 
35
36
  class BitWidthSearchMethod(Enum):
@@ -18,7 +18,7 @@ from typing import Dict, List
18
18
  import numpy as np
19
19
 
20
20
  from model_compression_toolkit.core.common import BaseNode
21
- from model_compression_toolkit.core.common.logger import Logger
21
+ from model_compression_toolkit.logger import Logger
22
22
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
23
  from model_compression_toolkit.core.common.graph.base_graph import Graph
24
24
  from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
@@ -18,7 +18,7 @@ from pulp import *
18
18
  from tqdm import tqdm
19
19
  from typing import Dict, List, Tuple, Callable
20
20
 
21
- from model_compression_toolkit.core.common import Logger
21
+ from model_compression_toolkit.logger import Logger
22
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
23
23
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
24
24
 
@@ -17,10 +17,10 @@ import copy
17
17
  import numpy as np
18
18
  from typing import Callable, Any, List
19
19
 
20
- from model_compression_toolkit import FrameworkInfo, MixedPrecisionQuantizationConfigV2
20
+ from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfigV2
21
21
  from model_compression_toolkit.core.common import Graph, BaseNode
22
22
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
23
- from model_compression_toolkit.core.common import Logger
23
+ from model_compression_toolkit.logger import Logger
24
24
 
25
25
 
26
26
  class SensitivityEvaluation:
@@ -15,12 +15,12 @@
15
15
 
16
16
  from typing import List
17
17
 
18
- from model_compression_toolkit import KPI
18
+ from model_compression_toolkit.core import KPI
19
19
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
20
20
  MixedPrecisionSearchManager
21
21
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
22
22
  CandidateNodeQuantizationConfig
23
- from model_compression_toolkit.core.common.logger import Logger
23
+ from model_compression_toolkit.logger import Logger
24
24
  import numpy as np
25
25
 
26
26
 
@@ -17,11 +17,11 @@
17
17
  import numpy as np
18
18
  from typing import List
19
19
 
20
- from model_compression_toolkit import FrameworkInfo
20
+ from model_compression_toolkit.core import FrameworkInfo
21
21
  from model_compression_toolkit.core import common
22
22
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
23
  from model_compression_toolkit.core.common.graph.base_graph import Graph
24
- from model_compression_toolkit.core.common.logger import Logger
24
+ from model_compression_toolkit.logger import Logger
25
25
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
26
26
 
27
27
 
@@ -1,7 +1,7 @@
1
1
  from abc import abstractmethod
2
2
  from typing import Any
3
3
 
4
- from model_compression_toolkit import FrameworkInfo
4
+ from model_compression_toolkit.core import FrameworkInfo
5
5
 
6
6
 
7
7
  class ModelValidation:
@@ -17,7 +17,10 @@ from abc import ABC, abstractmethod
17
17
  from collections import namedtuple
18
18
  from typing import Callable
19
19
 
20
- from model_compression_toolkit.core.common import Graph, Logger
20
+ from model_compression_toolkit.core.common import Graph
21
+ from model_compression_toolkit.logger import Logger
22
+
23
+
21
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
25
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
23
26
  from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import copy
16
15
  from typing import List
17
16
 
18
17
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -36,7 +35,6 @@ def edit_network_graph(graph: Graph,
36
35
  The graph after it has been applied the edit rules from the network editor list.
37
36
 
38
37
  """
39
- # graph = copy.deepcopy(graph_to_edit)
40
38
  for edit_rule in network_editor:
41
39
  filtered_nodes = graph.filter(edit_rule.filter)
42
40
  for node in filtered_nodes:
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.core.common.constants import ACTIVATION_QUANTIZATION_CFG, WEIGHTS_QUANTIZATION_CFG, QC, \
15
+ from model_compression_toolkit.constants import ACTIVATION_QUANTIZATION_CFG, WEIGHTS_QUANTIZATION_CFG, QC, \
16
16
  OP_CFG, ACTIVATION_QUANTIZATION_FN, WEIGHTS_QUANTIZATION_FN, ACTIVATION_QUANT_PARAMS_FN, WEIGHTS_QUANT_PARAMS_FN, \
17
17
  WEIGHTS_CHANNELS_AXIS
18
18
  from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
@@ -16,12 +16,12 @@ import copy
16
16
  from typing import List
17
17
 
18
18
  from model_compression_toolkit.core.common import Graph, BaseNode
19
- from model_compression_toolkit.core.common.constants import FLOAT_BITWIDTH
19
+ from model_compression_toolkit.constants import FLOAT_BITWIDTH
20
20
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
21
21
  CandidateNodeQuantizationConfig
22
22
 
23
23
 
24
- def filter_nodes_candidates(graph_to_filter: Graph):
24
+ def filter_nodes_candidates(graph: Graph):
25
25
  """
26
26
  Filters the graph's nodes candidates configuration list.
27
27
  We apply this after mark activation operation to eliminate nodes that their activation are no longer being quantized
@@ -29,9 +29,8 @@ def filter_nodes_candidates(graph_to_filter: Graph):
29
29
  Updating the lists is preformed inplace on the graph object.
30
30
 
31
31
  Args:
32
- graph_to_filter: Graph for which to add quantization info to each node.
32
+ graph: Graph for which to add quantization info to each node.
33
33
  """
34
- graph = copy.deepcopy(graph_to_filter)
35
34
  nodes = list(graph.nodes)
36
35
  for n in nodes:
37
36
  n.candidates_quantization_cfg = filter_node_candidates(node=n)
@@ -18,13 +18,13 @@ from typing import Callable, Any
18
18
 
19
19
  import numpy as np
20
20
 
21
- from model_compression_toolkit.core.common.logger import Logger
21
+ from model_compression_toolkit.logger import Logger
22
22
  from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
23
23
  get_activation_quantization_params_fn, get_weights_quantization_params_fn
24
24
 
25
25
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
26
26
  QuantizationErrorMethod
27
- from model_compression_toolkit.core.common.target_platform import OpQuantizationConfig
27
+ from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig
28
28
 
29
29
 
30
30
  ##########################################
@@ -256,7 +256,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
256
256
  self.weights_n_bits = op_cfg.weights_n_bits
257
257
  self.weights_bias_correction = qc.weights_bias_correction
258
258
  self.weights_second_moment_correction = qc.weights_second_moment_correction
259
- self.weights_per_channel_threshold = qc.weights_per_channel_threshold
259
+ self.weights_per_channel_threshold = op_cfg.weights_per_channel_threshold
260
260
  self.enable_weights_quantization = op_cfg.enable_weights_quantization
261
261
  self.min_threshold = qc.min_threshold
262
262
  self.l_p_value = qc.l_p_value