mct-nightly 1.8.0.8032023.post421-py3-none-any.whl → 1.8.0.8052023.post414-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -20,7 +20,7 @@ from model_compression_toolkit.core.common import BaseNode
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
  from model_compression_toolkit.core.common.substitutions.residual_collapsing import ResidualCollapsing
  from model_compression_toolkit.core.pytorch.constants import KERNEL
- from model_compression_toolkit.core.common.logger import Logger
+ from model_compression_toolkit.logger import Logger


  def residual_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
@@ -21,7 +21,7 @@ from torch import reshape
  from torch.nn.functional import hardswish, silu, prelu, elu
  from torch.nn.functional import avg_pool2d

- from model_compression_toolkit import CoreConfig, FrameworkInfo
+ from model_compression_toolkit.core import CoreConfig, FrameworkInfo
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import BaseNode, Graph
  from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher
@@ -15,21 +15,21 @@

  from typing import Callable

- from model_compression_toolkit.core.common import Logger
- from model_compression_toolkit.core.common.constants import PYTORCH
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import PYTORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
  MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG, MixedPrecisionQuantizationConfigV2
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
+ from model_compression_toolkit.constants import FOUND_TORCH

  if FOUND_TORCH:
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
- from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
  from torch.nn import Module

  from model_compression_toolkit import get_target_platform_capabilities
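Note: the hunk above shows the import reorganization that repeats throughout this release: the logger, the shared constants, and the target-platform classes moved from model_compression_toolkit.core.common up to top-level packages. A minimal sketch of the new-style imports, limited to the symbols visible in this hunk (old locations noted in comments):

    # Sketch only: new import locations per the hunk above.
    from model_compression_toolkit.logger import Logger  # was: model_compression_toolkit.core.common.logger
    from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH  # was: model_compression_toolkit.core.common.constants
    from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities  # was: model_compression_toolkit.core.common.target_platform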
@@ -51,7 +51,7 @@ if FOUND_TORCH:
  representative_data_gen (Callable): Dataset used for calibration.
  quant_config (MixedPrecisionQuantizationConfig): MixedPrecisionQuantizationConfig containing parameters of how the model should be quantized.
  fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. `Default PyTorch TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/pytorch_tp_models/pytorch_default.py>`_
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.

  Returns:
  A KPI object with total weights parameters sum, max activation tensor and total kpi.
@@ -75,7 +75,7 @@ if FOUND_TORCH:
  Import mct and call for KPI data calculation:

  >>> import model_compression_toolkit as mct
- >>> kpi_data = mct.pytorch_kpi_data(module, repr_datagen)
+ >>> kpi_data = mct.core.pytorch_kpi_data(module, repr_datagen)

  """

@@ -111,7 +111,7 @@ if FOUND_TORCH:
  representative_data_gen (Callable): Dataset used for calibration.
  core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
  fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. `Default PyTorch TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/pytorch_tp_models/pytorch_default.py>`_
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.

  Returns:

@@ -132,7 +132,7 @@ if FOUND_TORCH:
  Import mct and call for KPI data calculation:

  >>> import model_compression_toolkit as mct
- >>> kpi_data = mct.pytorch_kpi_data(module, repr_datagen)
+ >>> kpi_data = mct.core.pytorch_kpi_data(module, repr_datagen)

  """

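The doctest updates above reflect that the KPI facades are now reached through mct.core. A hedged usage sketch (the model and the random input shape are illustrative placeholders; per the docstring, the representative data generator returns a list of input arrays per call):

    # Sketch only: calling the relocated KPI facade; model and shape are placeholders.
    import numpy as np
    import model_compression_toolkit as mct
    from torchvision.models import mobilenet_v2

    def repr_datagen():
        return [np.random.random((1, 3, 224, 224))]

    kpi_data = mct.core.pytorch_kpi_data(mobilenet_v2(), repr_datagen)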
@@ -18,7 +18,7 @@ from typing import Any, List
  import torch
  import copy

- from model_compression_toolkit import FrameworkInfo
+ from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.core.common import BaseNode
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
  from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
@@ -13,6 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  import operator
+ from copy import deepcopy
  from typing import List, Any, Tuple, Callable, Type, Dict

  import numpy as np
@@ -22,7 +23,7 @@ from torch.nn import Conv2d, ConvTranspose2d, Linear
  from torch.nn import Module, Sigmoid, Softmax

  import model_compression_toolkit.core.pytorch.constants as pytorch_constants
- from model_compression_toolkit import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
+ from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -74,10 +75,7 @@ from model_compression_toolkit.core.pytorch.pytorch_node_prior_info import creat
  from model_compression_toolkit.core.pytorch.reader.reader import model_reader
  from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \
  pytorch_apply_second_moment_correction
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
- from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
- from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
- from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model


  class PytorchImplementation(FrameworkImplementation):
@@ -127,7 +125,9 @@ class PytorchImplementation(FrameworkImplementation):
  Returns:
  Graph representing the input module.
  """
- return model_reader(module, representative_data_gen, self.to_numpy, self.to_tensor)
+ _module = deepcopy(module)
+ _module.eval()
+ return model_reader(_module, representative_data_gen, self.to_numpy, self.to_tensor)

  def model_builder(self,
  graph: Graph,
@@ -323,12 +323,6 @@ class PytorchImplementation(FrameworkImplementation):
  substitutions_list.append(pytorch_batchnorm_refusing())
  return substitutions_list

- def get_gptq_trainer_obj(self) -> Type[GPTQTrainer]:
- """
- Returns: GPTQTrainer object
- """
- return PytorchGPTQTrainer
-
  def get_sensitivity_evaluator(self,
  graph: Graph,
  quant_config: MixedPrecisionQuantizationConfigV2,
@@ -16,7 +16,7 @@ from typing import Any, Tuple
  import numpy as np
  from torch.nn import BatchNorm2d

- from model_compression_toolkit import FrameworkInfo
+ from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.core.common import BaseNode, Graph
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
  from model_compression_toolkit.core.pytorch.constants import MOVING_MEAN, MOVING_VARIANCE, GAMMA, BETA
@@ -12,10 +12,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Tuple, Callable
+ from typing import Callable
  import torch

- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
+ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
  from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import threshold_is_power_of_two
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero

@@ -3,7 +3,7 @@ from typing import Dict, Callable
  import torch
  import numpy as np

- from model_compression_toolkit.core.common.constants import SIGNED, CLUSTER_CENTERS, THRESHOLD, MULTIPLIER_N_BITS, EPS
+ from model_compression_toolkit.constants import SIGNED, CLUSTER_CENTERS, THRESHOLD, MULTIPLIER_N_BITS, EPS
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor


@@ -25,6 +25,7 @@ from model_compression_toolkit.core.common.graph.functional_node import Function
  from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
  CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, GET_ATTR, CONSTANT, BUFFER
  from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, ConstantHolder, BufferHolder
+ from model_compression_toolkit.logger import Logger


  def extract_holder_weights(constant_name, node_target, model, weights, to_numpy):
@@ -64,6 +65,7 @@ def nodes_builder(model: GraphModule,
  Args:
  model: Pytorch FX model.
  module_dict: A dictionary of the Pyotrch model's named modules.
+ to_numpy: A function to convert a Tensor to numpy array

  Returns:
  A list of Graph nodes that were built from the fx GraphModule nodes.
@@ -91,7 +93,7 @@
  node_type = node.target
  if node_type == getattr:
  node_has_activation = False
- common.Logger.warning(
+ Logger.warning(
  'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
  'converting the model.')
  elif node.op == PLACEHOLDER:
@@ -112,7 +114,7 @@
  else:
  node_type = ConstantHolder
  node_has_activation = False
- common.Logger.warning(
+ Logger.warning(
  'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
  'converting the model.')
  else:
@@ -18,7 +18,7 @@ from typing import Any, Callable
  import torch
  from tqdm import tqdm

- from model_compression_toolkit import CoreConfig
+ from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.pytorch.constants import GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE
  from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
@@ -22,7 +22,7 @@ from tqdm import tqdm

  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import FrameworkInfo
- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
  from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -48,7 +48,7 @@ from model_compression_toolkit.core.common.statistics_correction.statistics_corr
  from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
  from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
  linear_collapsing_substitute
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
  from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
  WeightsFinalBitwidthConfigVisualizer, \
  ActivationFinalBitwidthConfigVisualizer
@@ -143,9 +143,9 @@ def core_runner(in_model: Any,
  weights_conf_nodes_bitwidth = tg.get_final_weights_config()
  activation_conf_nodes_bitwidth = tg.get_final_activation_config()

- common.Logger.info(
+ Logger.info(
  f'Final weights bit-width configuration: {[node_b[1] for node_b in weights_conf_nodes_bitwidth]}')
- common.Logger.info(
+ Logger.info(
  f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')

  if tb_w is not None:
@@ -259,9 +259,9 @@ def _init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
  A TensorBoardWriter object.
  """
  tb_w = None
- if common.Logger.LOG_PATH is not None:
- tb_log_dir = os.path.join(os.getcwd(), common.Logger.LOG_PATH, 'tensorboard_logs')
- common.Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
+ if Logger.LOG_PATH is not None:
+ tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
+ Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
  tb_w = TensorboardWriter(tb_log_dir, fw_info)
  return tb_w

@@ -12,3 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+
+ from model_compression_toolkit.exporter.model_exporter.keras.keras_export_facade import keras_export_model, KerasExportMode
+ from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import PyTorchExportMode, pytorch_export_model
+ from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import tflite_export_model, TFLiteExportMode
+
@@ -13,6 +13,3 @@
  # limitations under the License.
  # ==============================================================================

- from model_compression_toolkit.exporter.model_exporter.keras.keras_export_facade import keras_export_model, KerasExportMode
- from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import PyTorchExportMode, pytorch_export_model
- from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import tflite_export_model, TFLiteExportMode
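The two hunks above move the export facades up one level, so they are importable directly from model_compression_toolkit.exporter. A minimal sketch of the new-style import (names taken verbatim from the added lines; the move itself does not change the facade signatures):

    # Sketch only: export entry points re-exported from the exporter package root in this release.
    from model_compression_toolkit.exporter import (
        keras_export_model, KerasExportMode,
        pytorch_export_model, PyTorchExportMode,
        tflite_export_model, TFLiteExportMode,
    )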
@@ -17,7 +17,7 @@
  from abc import abstractmethod
  from typing import Any, Callable

- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger


  class Exporter:
@@ -19,7 +19,7 @@ import keras.models
  import tensorflow as tf
  from keras.engine.base_layer import Layer

- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.exporter.model_exporter.keras.base_keras_exporter import \
  BaseKerasExporter
  from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
@@ -15,8 +15,8 @@
  from enum import Enum
  from typing import Callable, Dict

- from model_compression_toolkit.core.common import Logger
- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TF


  class KerasExportMode(Enum):
@@ -16,17 +16,21 @@ from typing import Callable

  import torch.nn

- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
  from packaging import version

+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER
+
  # ONNX opset version 16 is supported from PyTorch 1.12
  if version.parse(torch.__version__) < version.parse("1.12"):
  OPSET_VERSION = 15
  else:
  OPSET_VERSION = 16

+
  class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
  """
  Exporter for fakely-quant PyTorch models.
@@ -70,6 +74,16 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):

  Logger.info(f"Exporting PyTorch fake quant onnx model: {self.save_model_path}")

+ # Replace float weight with wrapped quantized weights
+ for layer in self.model.modules():
+ if isinstance(layer, PytorchQuantizationWrapper):
+ for name in layer.weights_quantizers.keys():
+ quantized_weight = torch.nn.Parameter(layer.get_quantized_weights()[name]).detach()
+ linear_layer = getattr(layer, LAYER)
+ delattr(linear_layer, name)
+ setattr(linear_layer, name, torch.nn.Parameter(quantized_weight))
+ layer.weights_quantizers = {}
+
  torch.onnx.export(self.model,
  model_input,
  self.save_model_path,
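The block added above folds each wrapped layer's quantized weights back into the underlying module before torch.onnx.export is called, so the exported ONNX graph carries the quantized values. For illustration, a read-only sketch of the same traversal, e.g. to list which weights were quantized on an exportable model (names come from the added lines; exportable_model is a placeholder for the output of get_exportable_pytorch_model):

    # Sketch only: inspect quantized weight attributes without mutating the model.
    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper

    for module in exportable_model.modules():
        if isinstance(module, PytorchQuantizationWrapper):
            for weight_name in module.weights_quantizers.keys():
                q_weight = module.get_quantized_weights()[weight_name]
                print(weight_name, tuple(q_weight.shape))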
@@ -16,7 +16,7 @@ from typing import Callable

  import torch.nn

- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter

@@ -15,8 +15,8 @@
  from enum import Enum
  from typing import Callable

- from model_compression_toolkit.core.common import Logger
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TORCH


  class PyTorchExportMode(Enum):
@@ -19,8 +19,8 @@ from typing import Callable
  import keras.models
  import tensorflow as tf

- from model_compression_toolkit import keras_load_quantized_model
- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter


@@ -23,7 +23,7 @@ from keras.layers import Dense, Conv2D, Reshape
  from keras.models import clone_model

  from model_compression_toolkit import quantizers_infrastructure as qi
- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
  constants as keras_inferable_constants
@@ -15,8 +15,8 @@
  from enum import Enum
  from typing import Callable

- from model_compression_toolkit.core.common import Logger
- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TF


  class TFLiteExportMode(Enum):
@@ -13,12 +13,8 @@
  # limitations under the License.
  # ==============================================================================

- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH
+ from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
+ from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model

- if FOUND_TF:
- from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
- from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model
-
- if FOUND_TORCH:
- from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
- from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
@@ -14,46 +14,53 @@
  # ==============================================================================
  from typing import Tuple

- import tensorflow as tf
- from tensorflow.keras.layers import Layer

  from model_compression_toolkit import quantizers_infrastructure as qi
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import Graph
+ from model_compression_toolkit.constants import FOUND_TF
  from model_compression_toolkit.core.common.user_info import UserInformation
- from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
- from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import \
-     get_quantization_quantizers
-
-
- def _get_wrapper(node: common.BaseNode,
-                  layer: Layer) -> qi.KerasQuantizationWrapper:
-     """
-     A function which takes a computational graph node and a keras layer and perform the quantization wrapping
-     Args:
-         n: A node of mct graph.
-         layer: A keras layer
-
-     Returns: Wrapped layer with weights quantizers and activation quantizers
-
-     """
-     weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
-     return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
-
-
- def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model,UserInformation]:
-     """
-     Convert graph to an exportable Keras model (model with all quantization parameters).
-     An exportable model can then be exported using model_exporter, to retrieve the
-     final exported model.
-
-     Args:
-         graph: Graph to convert to an exportable Keras model.
-
-     Returns:
-         Exportable Keras model and user information.
-     """
-     exportable_model, user_info = KerasModelBuilder(graph=graph,
-                                                     wrapper=_get_wrapper).build_model()
-     exportable_model.trainable = False
-     return exportable_model, user_info
+ from model_compression_toolkit.logger import Logger
+
+ if FOUND_TF:
+     import tensorflow as tf
+     from tensorflow.keras.layers import Layer
+     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+     from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
+
+     def _get_wrapper(node: common.BaseNode,
+                      layer: Layer) -> qi.KerasQuantizationWrapper:
+         """
+         A function which takes a computational graph node and a keras layer and perform the quantization wrapping
+         Args:
+             n: A node of mct graph.
+             layer: A keras layer
+
+         Returns: Wrapped layer with weights quantizers and activation quantizers
+
+         """
+         weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
+         return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
+
+
+     def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
+         """
+         Convert graph to an exportable Keras model (model with all quantization parameters).
+         An exportable model can then be exported using model_exporter, to retrieve the
+         final exported model.
+
+         Args:
+             graph: Graph to convert to an exportable Keras model.
+
+         Returns:
+             Exportable Keras model and user information.
+         """
+         exportable_model, user_info = KerasModelBuilder(graph=graph,
+                                                         wrapper=_get_wrapper).build_model()
+         exportable_model.trainable = False
+         return exportable_model, user_info
+ else:
+     def get_exportable_keras_model(*args, **kwargs):  # pragma: no cover
+         Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                      'when using get_exportable_keras_model. '
+                      'Could not find Tensorflow package.')
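The hunk above moves the TensorFlow-dependent pieces of get_exportable_keras_model behind a FOUND_TF guard, so importing the module no longer requires TensorFlow and the missing dependency is reported only when the function is actually called. A minimal, generic sketch of that guard pattern (standalone illustration, not MCT's code; in the package FOUND_TF is computed in model_compression_toolkit.constants rather than locally, and build_model is an invented name):

# Generic availability-guard sketch: detect the optional dependency once,
# then define either the real function or a stub that fails loudly on use.
try:
    import tensorflow as tf
    FOUND_TF = True
except ImportError:
    FOUND_TF = False

if FOUND_TF:
    def build_model():
        # TensorFlow-dependent implementation.
        return tf.keras.Sequential()
else:
    def build_model(*args, **kwargs):
        # Importing this module still works; only calling the function fails.
        raise ImportError('TensorFlow is required to use build_model, but it is not installed.')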
@@ -14,16 +14,15 @@
  # ==============================================================================
  from typing import Dict, Any

- from model_compression_toolkit.core.common import BaseNode, Logger
- from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
-     get_inferable_quantizer_class
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer \
-     import \
-     BaseKerasInferableQuantizer
+ from model_compression_toolkit.core.common import BaseNode
+ from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL

+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts

  def get_inferable_quantizer_kwargs(node: BaseNode,
                                     quantization_target: QuantizationTarget) -> Dict[str, Any]:
@@ -44,19 +43,29 @@ def get_inferable_quantizer_kwargs(node: BaseNode,
          # Return the appropriate quantization parameters based on the quantization method
          if quantization_method in [QuantizationMethod.POWER_OF_TWO,
                                     QuantizationMethod.SYMMETRIC]:
-             return {'num_bits': node_w_qc.weights_n_bits,
-                     'threshold': list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
-                     'per_channel': node_w_qc.weights_per_channel_threshold,
-                     'channel_axis': node_w_qc.weights_channels_axis,
-                     'input_rank': len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
+             return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                     qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
+                     qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                     qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                     qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}

          elif quantization_method in [QuantizationMethod.UNIFORM]:
-             return {'num_bits': node_w_qc.weights_n_bits,
-                     'per_channel': node_w_qc.weights_per_channel_threshold,
-                     'min_range': list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
-                     'max_range': list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
-                     'channel_axis': node_w_qc.weights_channels_axis,
-                     'input_rank': len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
+             return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                     qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                     qi_keras_consts.MIN_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
+                     qi_keras_consts.MAX_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
+                     qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                     qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
+
+         elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
+             return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                     qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                     qi_keras_consts.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS],
+                     qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
+                     qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                     # TODO: how to pass multiplier nbits and eps for a specific node?
+                     qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
+
          else:
              Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover
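The new LUT_SYM_QUANTIZER/LUT_POT_QUANTIZER branch above forwards cluster centers and per-channel scales (stored under CLUSTER_CENTERS and SCALE_PER_CHANNEL) to the inferable weights quantizer. As a rough, generic illustration of what a look-up-table quantizer does with those two arrays (a sketch only, with invented names; it is not the kernel MCT ships):

import numpy as np

def lut_quantize(weights, cluster_centers, scale_per_channel, channel_axis=-1):
    # Normalize each output channel by its scale, snap every value to the
    # nearest cluster center, then rescale back to the original range.
    shape = [1] * weights.ndim
    shape[channel_axis] = -1
    scale = scale_per_channel.reshape(shape)
    scaled = weights / scale
    idx = np.argmin(np.abs(scaled[..., np.newaxis] - cluster_centers), axis=-1)
    return cluster_centers[idx] * scale

# Example: quantize a Conv2D-shaped kernel with 4 cluster centers and per-output-channel scales.
w = np.random.randn(3, 3, 8, 16).astype(np.float32)
w_q = lut_quantize(w, np.array([-1.0, -0.25, 0.25, 1.0]), np.abs(w).max(axis=(0, 1, 2)))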
 
@@ -68,16 +77,24 @@
          # Return the appropriate quantization parameters based on the quantization method
          if quantization_method in [QuantizationMethod.POWER_OF_TWO,
                                     QuantizationMethod.SYMMETRIC]:
-             return {'num_bits': node_qc.activation_n_bits,
+             return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
                      # In activation quantization is per-tensor only - thus we hold the threshold as a list with a len of 1
-                     'threshold': [node_qc.activation_quantization_params[THRESHOLD]],
-                     'signed': node_qc.activation_quantization_params[SIGNED]}
+                     qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]],
+                     qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED]}

          elif quantization_method in [QuantizationMethod.UNIFORM]:
-             return {'num_bits': node_qc.activation_n_bits,
+             return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
                      # In activation quantization is per-tensor only - thus we hold the min/max as a list with a len of 1
-                     'min_range': [node_qc.activation_quantization_params[RANGE_MIN]],
-                     'max_range': [node_qc.activation_quantization_params[RANGE_MAX]]}
+                     qi_keras_consts.MIN_RANGE: [node_qc.activation_quantization_params[RANGE_MIN]],
+                     qi_keras_consts.MAX_RANGE: [node_qc.activation_quantization_params[RANGE_MAX]]}
+
+         elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+             return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                     qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED],
+                     qi_keras_consts.CLUSTER_CENTERS: node_qc.activation_quantization_params[CLUSTER_CENTERS],
+                     qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]]
+                     # TODO: how to pass multiplier nbits and eps for a specific node?
+                     }
          else:
              Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover
      else:
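A related change in both kwargs hunks: literal string keys ('num_bits', 'threshold', ...) are replaced by constants imported as qi_keras_consts from the Keras inferable-quantizer constants module. The surrounding file also imports get_inferable_quantizer_class, so these kwargs appear intended to be unpacked into the matching quantizer's constructor, which means the keys must track the constructor's parameter names; centralizing them as constants makes that coupling explicit. A generic sketch of the pattern (invented class and values, not MCT's code):

# Shared key constants, mirroring the parameter names of the quantizer __init__.
NUM_BITS = 'num_bits'
THRESHOLD = 'threshold'
SIGNED = 'signed'

class SymmetricActivationQuantizer:
    def __init__(self, num_bits, threshold, signed):
        self.num_bits, self.threshold, self.signed = num_bits, threshold, signed

kwargs = {NUM_BITS: 8, THRESHOLD: [4.0], SIGNED: True}
quantizer = SymmetricActivationQuantizer(**kwargs)  # keys resolve to the constructor's parameter names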