mct-nightly 1.8.0.22032023.post333__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +4 -3
  2. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +294 -284
  3. model_compression_toolkit/__init__.py +9 -32
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/framework_info.py +1 -1
  13. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  14. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  15. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  18. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  19. model_compression_toolkit/core/common/memory_computation.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  28. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  29. model_compression_toolkit/core/common/model_collector.py +2 -2
  30. model_compression_toolkit/core/common/model_validation.py +1 -1
  31. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  32. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  33. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  35. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  36. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  37. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  50. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  52. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  54. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  55. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  56. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  57. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  58. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  60. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  63. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  65. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  66. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  67. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  69. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  73. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  74. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  75. model_compression_toolkit/core/keras/constants.py +0 -7
  76. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  85. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  86. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  87. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  88. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  89. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  90. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  91. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  92. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  93. model_compression_toolkit/core/keras/reader/common.py +1 -1
  94. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  95. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  99. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  100. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/constants.py +0 -6
  102. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  103. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  109. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  110. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  111. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  112. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  113. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  114. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  115. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  116. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  117. model_compression_toolkit/core/runner.py +7 -7
  118. model_compression_toolkit/exporter/__init__.py +6 -3
  119. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  120. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  121. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  124. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  125. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  126. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  127. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  128. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  129. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  130. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
  131. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  132. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  133. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  135. model_compression_toolkit/gptq/common/gptq_config.py +2 -4
  136. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  137. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  138. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  139. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  140. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  141. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  142. model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
  143. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  144. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  145. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  146. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
  147. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +21 -16
  148. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  149. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
  150. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  151. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  152. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
  153. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  154. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  155. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  156. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
  157. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +13 -5
  158. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  159. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
  160. model_compression_toolkit/gptq/runner.py +3 -2
  161. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
  162. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  163. model_compression_toolkit/ptq/__init__.py +3 -0
  164. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  165. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  166. model_compression_toolkit/qat/__init__.py +4 -0
  167. model_compression_toolkit/qat/common/__init__.py +1 -2
  168. model_compression_toolkit/qat/common/qat_config.py +5 -1
  169. model_compression_toolkit/qat/keras/quantization_facade.py +34 -28
  170. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  171. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  172. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
  173. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
  174. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  175. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  176. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
  177. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
  178. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  179. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +3 -5
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -3
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  211. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  212. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  213. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
  214. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  215. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  217. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  218. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  219. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  220. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  221. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  222. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  223. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  224. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  225. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  226. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  227. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  228. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  229. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  233. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  234. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  235. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  236. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  237. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  238. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  239. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  240. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  241. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  242. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  243. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  244. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  248. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  250. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  254. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  255. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  257. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  259. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  261. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  263. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  265. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  273. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  274. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  275. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  276. model_compression_toolkit/exporter/model_exporter/tflite/__init__.py +0 -14
  277. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  278. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  279. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  280. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  281. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  282. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  283. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  284. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  285. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  286. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  287. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  288. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  289. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  290. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  291. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  292. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  293. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  294. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -12,63 +12,86 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from enum import Enum
16
15
  from typing import Callable
17
16
 
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
20
-
21
-
22
- class PyTorchExportMode(Enum):
23
- FAKELY_QUANT_TORCHSCRIPT = 0
24
- FAKELY_QUANT_ONNX = 1
25
-
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
19
+ PytorchExportSerializationFormat
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
23
+ QuantizationFormat
26
24
 
27
25
  if FOUND_TORCH:
28
26
  import torch.nn
29
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
30
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
27
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
28
+ FakelyQuantONNXPyTorchExporter
29
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import \
30
+ FakelyQuantTorchScriptPyTorchExporter
31
31
  from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
32
32
 
33
+ supported_serialization_quantization_export_dict = {
34
+ PytorchExportSerializationFormat.TORCHSCRIPT: [QuantizationFormat.FAKELY_QUANT],
35
+ PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT]
36
+ }
37
+
33
38
  def pytorch_export_model(model: torch.nn.Module,
34
39
  save_model_path: str,
35
40
  repr_dataset: Callable,
41
+ target_platform_capabilities: TargetPlatformCapabilities,
36
42
  is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
37
- mode: PyTorchExportMode = PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT) -> None:
43
+ serialization_format: PytorchExportSerializationFormat =
44
+ PytorchExportSerializationFormat.TORCHSCRIPT) -> None:
38
45
  """
39
46
  Export a PyTorch quantized model to a torchscript or onnx model.
40
47
  The model will be saved to the path in save_model_path.
41
- Mode can be used for different exported files. Currently, pytorch_export_model
42
- supports PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT (where the exported model
43
- is in a TorchScript format and its weights and activations are float fakely-quantized values),
44
- and PyTorchExportMode.FakelyQuantONNX (where the exported model
45
- is in an ONNX format and its weights and activations are float fakely-quantized values)
48
+ Currently, pytorch_export_model supports only QuantizationFormat.FAKELY_QUANT (where weights
49
+ and activations are float fakely-quantized values) and PytorchExportSerializationFormat.TORCHSCRIPT
50
+ (where the model will be saved to TorchScript model) or PytorchExportSerializationFormat.ONNX
51
+ (where the model will be saved to ONNX model).
46
52
 
47
53
  Args:
48
54
  model: Model to export.
49
- is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
50
- mode: Mode to export the model according to.
51
55
  save_model_path: Path to save the model.
52
56
  repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
57
+ target_platform_capabilities: TargetPlatformCapabilities object that describes the desired inference
58
+ target platform (includes quantization format).
59
+ is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
60
+ serialization_format: Format to export the model according to (by default
61
+ PytorchExportSerializationFormat.TORCHSCRIPT).
53
62
 
54
63
  """
55
64
 
56
- if mode == PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT:
57
- exporter = FakelyQuantTorchScriptPyTorchExporter(model,
58
- is_layer_exportable_fn,
59
- save_model_path,
60
- repr_dataset)
65
+ if serialization_format == PytorchExportSerializationFormat.TORCHSCRIPT:
66
+ if target_platform_capabilities.tp_model.quantization_format in \
67
+ supported_serialization_quantization_export_dict[serialization_format]:
68
+ exporter = FakelyQuantTorchScriptPyTorchExporter(model,
69
+ is_layer_exportable_fn,
70
+ save_model_path,
71
+ repr_dataset)
72
+ else:
73
+ Logger.critical(
74
+ f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
75
+ f'serialization {serialization_format} was used to export Pytorch model. Please see API for '
76
+ f'supported formats.') # pragma: no cover
61
77
 
62
- elif mode == PyTorchExportMode.FAKELY_QUANT_ONNX:
63
- exporter = FakelyQuantONNXPyTorchExporter(model,
64
- is_layer_exportable_fn,
65
- save_model_path,
66
- repr_dataset)
78
+ elif serialization_format == PytorchExportSerializationFormat.ONNX:
79
+ if target_platform_capabilities.tp_model.quantization_format in \
80
+ supported_serialization_quantization_export_dict[serialization_format]:
81
+ exporter = FakelyQuantONNXPyTorchExporter(model,
82
+ is_layer_exportable_fn,
83
+ save_model_path,
84
+ repr_dataset)
85
+ else:
86
+ Logger.critical(
87
+ f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
88
+ f'serialization {serialization_format} was used to export Pytorch model. Please see API for '
89
+ f'supported formats.') # pragma: no cover
67
90
 
68
91
  else:
69
92
  Logger.critical(
70
- f'Unsupported mode was used {mode.name} to export PyTorch model. '
71
- f'Please see API for supported modes.') # pragma: no cover
93
+ f'Unsupported serialization {serialization_format} was used to export Pytorch model. Please see API '
94
+ f'for supported formats.') # pragma: no cover
72
95
 
73
96
  exporter.export()
74
97
 
@@ -17,9 +17,10 @@ from typing import Tuple
17
17
 
18
18
  from model_compression_toolkit import quantizers_infrastructure as qi
19
19
  from model_compression_toolkit.core import common
20
- from model_compression_toolkit.core.common import Graph, Logger
21
- from model_compression_toolkit.core.common.constants import FOUND_TF
20
+ from model_compression_toolkit.core.common import Graph
21
+ from model_compression_toolkit.constants import FOUND_TF
22
22
  from model_compression_toolkit.core.common.user_info import UserInformation
23
+ from model_compression_toolkit.logger import Logger
23
24
 
24
25
  if FOUND_TF:
25
26
  import tensorflow as tf
@@ -34,6 +35,7 @@ if FOUND_TF:
34
35
  Args:
35
36
  n: A node of mct graph.
36
37
  layer: A keras layer
38
+ include_activation_quantizers: Whether to use the wrapper for the activation quantizer or not
37
39
 
38
40
  Returns: Wrapped layer with weights quantizers and activation quantizers
39
41
 
@@ -55,7 +57,7 @@ if FOUND_TF:
55
57
  Exportable Keras model and user information.
56
58
  """
57
59
  exportable_model, user_info = KerasModelBuilder(graph=graph,
58
- wrapper=_get_wrapper).build_model()
60
+ wrapper=_get_wrapper).build_model()
59
61
  exportable_model.trainable = False
60
62
  return exportable_model, user_info
61
63
  else:
@@ -14,9 +14,11 @@
14
14
  # ==============================================================================
15
15
  from typing import Dict, Any
16
16
 
17
- from model_compression_toolkit.core.common import BaseNode, Logger
18
- from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.core.common import BaseNode
18
+ from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
19
+
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
21
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
22
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
@@ -15,8 +15,8 @@
15
15
  from typing import Any
16
16
 
17
17
 
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import FOUND_TF
20
20
 
21
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
22
22
 
@@ -16,8 +16,9 @@
16
16
 
17
17
  from model_compression_toolkit import quantizers_infrastructure as qi
18
18
  from model_compression_toolkit.core import common
19
- from model_compression_toolkit.core.common import Graph, Logger
20
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.core.common import Graph
20
+ from model_compression_toolkit.constants import FOUND_TORCH
21
+ from model_compression_toolkit.logger import Logger
21
22
 
22
23
  if FOUND_TORCH:
23
24
  import torch
@@ -15,10 +15,11 @@
15
15
 
16
16
  from typing import Dict, Any
17
17
 
18
- from model_compression_toolkit.core.common import BaseNode, Logger
19
- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
18
+ from model_compression_toolkit.core.common import BaseNode
19
+ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
20
20
  SCALE_PER_CHANNEL, CLUSTER_CENTERS
21
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.logger import Logger
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
23
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
23
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
24
25
  get_inferable_quantizer_class
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Any
16
16
 
17
- from model_compression_toolkit.core.common import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
19
 
20
20
  if FOUND_TORCH:
21
21
  from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
@@ -111,10 +111,8 @@ class GradientPTQConfig:
111
111
  self.regularization_factor = regularization_factor
112
112
  self.hessian_weights_config = hessian_weights_config
113
113
 
114
- # Since the default quantizer is soft quantizer, we initialize the gptq_quantizer_params_override dictionary
115
- # with its extended params
116
- self.gptq_quantizer_params_override = {QUANT_PARAM_LEARNING_STR: False} \
117
- if gptq_quantizer_params_override is None else gptq_quantizer_params_override
114
+ self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
115
+ else gptq_quantizer_params_override
118
116
 
119
117
 
120
118
  class GradientPTQConfigV2(GradientPTQConfig):
@@ -0,0 +1,32 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from abc import abstractmethod
17
+
18
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
19
+
20
+
21
+ class GPTQFrameworkImplemantation(FrameworkImplementation):
22
+ """
23
+ Class to implement framework related methods that are used in GPTQ
24
+ """
25
+
26
+ @abstractmethod
27
+ def get_gptq_trainer_obj(self):
28
+ """
29
+ Returns: GPTQTrainer object
30
+ """
31
+ raise NotImplemented(f'{self.__class__.__name__} have to implement the '
32
+ f'framework\'s get_gptq_trainer method.') # pragma: no cover
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Tuple, List
16
16
 
17
- from model_compression_toolkit import FrameworkInfo
18
- from model_compression_toolkit.core.common import Logger
17
+ from model_compression_toolkit.core import FrameworkInfo
18
+ from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.core.common.graph.base_graph import Graph
20
20
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
21
21
 
@@ -17,12 +17,13 @@ from abc import ABC, abstractmethod
17
17
  import numpy as np
18
18
  from typing import Callable, List, Any
19
19
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
20
- from model_compression_toolkit.core.common import Graph, Logger, BaseNode
20
+ from model_compression_toolkit.core.common import Graph, BaseNode
21
21
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
- from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
22
  from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
23
+ from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
24
24
  from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
25
25
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
26
+ from model_compression_toolkit.logger import Logger
26
27
 
27
28
 
28
29
  class GPTQTrainer(ABC):
@@ -34,7 +35,7 @@ class GPTQTrainer(ABC):
34
35
  graph_float: Graph,
35
36
  graph_quant: Graph,
36
37
  gptq_config: GradientPTQConfig,
37
- fw_impl: FrameworkImplementation,
38
+ fw_impl: GPTQFrameworkImplemantation,
38
39
  fw_info: FrameworkInfo):
39
40
  """
40
41
  Build two models from a graph: A teacher network (float model) and a student network (quantized model).
@@ -259,7 +260,7 @@ def gptq_training(graph_float: Graph,
259
260
  graph_quant: Graph,
260
261
  gptq_config: GradientPTQConfig,
261
262
  representative_data_gen: Callable,
262
- fw_impl: FrameworkImplementation,
263
+ fw_impl: GPTQFrameworkImplemantation,
263
264
  fw_info: FrameworkInfo) -> Graph:
264
265
  """
265
266
  GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).
@@ -0,0 +1,29 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Type
17
+
18
+ from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
19
+ from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
20
+ from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
21
+
22
+
23
+ class GPTQKerasImplemantation(GPTQFrameworkImplemantation, KerasImplementation):
24
+
25
+ def get_gptq_trainer_obj(self) -> Type[KerasGPTQTrainer]:
26
+ """
27
+ Returns: Keras object of GPTQTrainer
28
+ """
29
+ return KerasGPTQTrainer
@@ -16,17 +16,18 @@ from typing import Callable, List, Tuple, Union
16
16
 
17
17
  import tensorflow as tf
18
18
  from keras import Model
19
+ from packaging import version
19
20
  from tensorflow.keras.layers import Layer
20
21
  from tqdm import tqdm
21
22
 
22
23
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
23
24
  from model_compression_toolkit.core.common.user_info import UserInformation
24
25
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
25
- from packaging import version
26
-
27
26
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
28
27
  from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
28
+ from model_compression_toolkit.logger import Logger
29
29
  from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
30
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder
30
31
 
31
32
  if version.parse(tf.__version__) < version.parse("2.6"):
32
33
  from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
@@ -105,7 +106,7 @@ class KerasGPTQTrainer(GPTQTrainer):
105
106
  [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
106
107
 
107
108
  if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
108
- common.Logger.error("Input scale mismatch between float and GPTQ networks") # pragma: no cover
109
+ Logger.error("Input scale mismatch between float and GPTQ networks") # pragma: no cover
109
110
  else:
110
111
  self.input_scale = self.gptq_user_info.input_scale
111
112
 
@@ -113,8 +114,8 @@ class KerasGPTQTrainer(GPTQTrainer):
113
114
 
114
115
  self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
115
116
 
116
- def _is_gptq_applicable(self,
117
- node: common.BaseNode) -> bool:
117
+ def _is_gptq_weights_trainable(self,
118
+ node: common.BaseNode) -> bool:
118
119
  """
119
120
  A function for deciding if a layer should be fine-tuned during GPTQ.
120
121
 
@@ -126,11 +127,13 @@ class KerasGPTQTrainer(GPTQTrainer):
126
127
  """
127
128
 
128
129
  if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
129
- common.Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
130
+ Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
130
131
  f"without a kernel isn't supported")
131
132
  return node.is_weights_quantization_enabled()
132
133
 
133
- def gptq_wrapper(self, n: common.BaseNode, layer: Layer) -> Union[qi.KerasQuantizationWrapper, Layer]:
134
+ def gptq_wrapper(self,
135
+ n: common.BaseNode,
136
+ layer: Layer) -> Union[qi.KerasQuantizationWrapper, Layer]:
134
137
  """
135
138
  A function which takes a computational graph node and a keras layer and perform the quantization wrapping.
136
139
 
@@ -141,14 +144,37 @@ class KerasGPTQTrainer(GPTQTrainer):
141
144
  Returns: Wrapped layer if the layer should be wrap, otherwise returns the layer as is.
142
145
 
143
146
  """
144
- if self._is_gptq_applicable(n):
145
- weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config)
147
+ if self._is_gptq_weights_trainable(n):
148
+ weights_quantizers, _ = quantization_builder(n, self.gptq_config) # TODO: split quantizers building into two functions: for weights and activations
146
149
  return qi.KerasQuantizationWrapper(layer,
147
- weights_quantizers=weights_quantizers,
148
- activation_quantizers=activation_quantizers)
150
+ weights_quantizers=weights_quantizers)
149
151
  else:
150
152
  return layer
151
153
 
154
+ def get_activation_quantizer_holder(self, n: common.BaseNode) -> Union[None, Callable]:
155
+ """
156
+ Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
157
+ If the layer is not supposed to be wrapped with activation quantizers - return None.
158
+
159
+ Args:
160
+ n: Node to get ActivationQuantizationHolder to attach in its output.
161
+
162
+ Returns:
163
+ A ActivationQuantizationHolder layer for the node activation quantization.
164
+ """
165
+ _, activation_quantizers = quantization_builder(n, self.gptq_config) # TODO: split quantizers building into two functions: for weights and activations
166
+
167
+ # Holder by definition uses a single quantizer for the activation quantization
168
+ # thus we make sure this is the only possible case (unless it's a node with no activation
169
+ # quantization, which in this case has an empty list).
170
+ if len(activation_quantizers) == 1:
171
+ return ActivationQuantizationHolder(activation_quantizers[0])
172
+
173
+ Logger.error(
174
+ f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
175
+ f'were found for node {n}')
176
+
177
+
152
178
  def build_gptq_model(self) -> Tuple[Model, UserInformation]:
153
179
  """
154
180
  Build the GPTQ model with QuantizationWrappers
@@ -161,7 +187,8 @@ class KerasGPTQTrainer(GPTQTrainer):
161
187
  append2output=self.compare_points,
162
188
  fw_info=self.fw_info,
163
189
  return_float_outputs=True,
164
- wrapper=self.gptq_wrapper).build_model()
190
+ wrapper=self.gptq_wrapper,
191
+ get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
165
192
 
166
193
  return gptq_model, gptq_user_info
167
194
 
@@ -280,7 +307,7 @@ class KerasGPTQTrainer(GPTQTrainer):
280
307
  self.gptq_config.log_function(loss_value_step, grads[0], in_optimizer_with_param[0][-1],
281
308
  self.compare_points)
282
309
  self.loss_list.append(loss_value_step.numpy())
283
- common.Logger.debug(f'last loss value: {self.loss_list[-1]}')
310
+ Logger.debug(f'last loss value: {self.loss_list[-1]}')
284
311
 
285
312
  def update_graph(self):
286
313
  """
@@ -297,7 +324,7 @@ class KerasGPTQTrainer(GPTQTrainer):
297
324
  if len(node) == 0 and isinstance(layer.layer, TensorFlowOpLayer):
298
325
  node = graph.find_node_by_name('_'.join(layer.layer.name.split('_')[3:]))
299
326
  if len(node) != 1:
300
- common.Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
327
+ Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
301
328
  node = node[0]
302
329
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
303
330
  fw_info=self.fw_info)
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
20
20
  from tensorflow.keras.models import Model
21
21
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
22
22
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
23
+ from model_compression_toolkit.logger import Logger
23
24
  from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
24
25
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
25
26
 
@@ -50,6 +51,9 @@ def get_gptq_trainable_parameters(fxp_model: Model,
50
51
  fw_info=DEFAULT_KERAS_INFO)
51
52
 
52
53
  # collect trainable weights per quantizer
54
+ if kernel_attribute not in layer.weights_quantizers:
55
+ Logger.error(f'{kernel_attribute} was not found in weight quantizers of layer {layer.layer}')
56
+
53
57
  quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
54
58
  quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
55
59
  trainable_weights.append(quantizer_trainable_weights)
@@ -16,21 +16,19 @@
16
16
  from typing import Callable, Tuple
17
17
  from packaging import version
18
18
 
19
- from model_compression_toolkit.core import common
20
- from model_compression_toolkit.core.common import Logger
21
- from model_compression_toolkit.core.common.constants import TENSORFLOW
19
+ from model_compression_toolkit.logger import Logger
20
+ from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
22
21
  from model_compression_toolkit.core.common.user_info import UserInformation
23
22
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
24
23
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
25
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
26
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
27
- MixedPrecisionQuantizationConfigV2
28
- from model_compression_toolkit import CoreConfig
25
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
26
+ from model_compression_toolkit.core import CoreConfig
29
27
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
30
28
  from model_compression_toolkit.gptq.runner import gptq_runner
31
29
  from model_compression_toolkit.core.exporter import export_model
32
30
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
33
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
31
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
34
32
 
35
33
  LR_DEFAULT = 0.15
36
34
  LR_REST_DEFAULT = 1e-4
@@ -38,14 +36,14 @@ LR_BIAS_DEFAULT = 1e-4
38
36
  LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
39
37
  GPTQ_MOMENTUM = 0.9
40
38
 
41
- if common.constants.FOUND_TF:
39
+ if FOUND_TF:
42
40
  import tensorflow as tf
43
41
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
44
- from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
42
+ from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
45
43
  from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
46
44
  from tensorflow.keras.models import Model
47
45
  from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
48
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
46
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
49
47
  from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
50
48
  from model_compression_toolkit import get_target_platform_capabilities
51
49
 
@@ -62,7 +60,8 @@ if common.constants.FOUND_TF:
62
60
  optimizer: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT),
63
61
  optimizer_rest: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT),
64
62
  loss: Callable = GPTQMultipleTensorsLoss(),
65
- log_function: Callable = None) -> GradientPTQConfigV2:
63
+ log_function: Callable = None,
64
+ use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
66
65
  """
67
66
  Create a GradientPTQConfigV2 instance for Keras models.
68
67
 
@@ -72,6 +71,7 @@ if common.constants.FOUND_TF:
72
71
  optimizer_rest (OptimizerV2): Keras optimizer to use for fine-tuning of the bias variable.
73
72
  loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
74
73
  log_function (Callable): Function to log information about the gptq process.
74
+ use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
75
75
 
76
76
  returns:
77
77
  a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -94,9 +94,16 @@ if common.constants.FOUND_TF:
94
94
  The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.
95
95
 
96
96
  """
97
- bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
98
- return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
99
- log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)
97
+ bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
98
+ momentum=GPTQ_MOMENTUM)
99
+ return GradientPTQConfigV2(n_epochs,
100
+ optimizer,
101
+ optimizer_rest=optimizer_rest,
102
+ loss=loss,
103
+ log_function=log_function,
104
+ train_bias=True,
105
+ optimizer_bias=bias_optimizer,
106
+ use_hessian_based_weights=use_hessian_based_weights)
100
107
 
101
108
 
102
109
  def keras_gradient_post_training_quantization_experimental(in_model: Model,
@@ -158,20 +165,20 @@ if common.constants.FOUND_TF:
158
165
 
159
166
  Create an MCT core config, containing the quantization configuration:
160
167
 
161
- >>> config = mct.CoreConfig()
168
+ >>> config = mct.core.CoreConfig()
162
169
 
163
170
  If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
164
171
  with different bitwidths for different layers.
165
172
  The candidates bitwidth for quantization should be defined in the target platform model:
166
173
 
167
- >>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2(num_of_images=1))
174
+ >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
168
175
 
169
176
  For mixed-precision set a target KPI object:
170
177
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
171
178
  that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
172
179
  while the bias will not):
173
180
 
174
- >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
181
+ >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
175
182
 
176
183
  Create GPTQ config:
177
184
 
@@ -187,15 +194,15 @@ if common.constants.FOUND_TF:
187
194
 
188
195
  if core_config.mixed_precision_enable:
189
196
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
190
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
197
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
191
198
  "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
192
199
  "API, or pass a valid mixed precision configuration.") # pragma: no cover
193
200
 
194
- common.Logger.info("Using experimental mixed-precision quantization. "
201
+ Logger.info("Using experimental mixed-precision quantization. "
195
202
  "If you encounter an issue please file a bug.")
196
203
  tb_w = _init_tensorboard_writer(fw_info)
197
204
 
198
- fw_impl = KerasImplementation()
205
+ fw_impl = GPTQKerasImplemantation()
199
206
 
200
207
  tg, bit_widths_config = core_runner(in_model=in_model,
201
208
  representative_data_gen=representative_data_gen,
@@ -15,3 +15,4 @@
15
15
 
16
16
  import model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste
17
17
  import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.symmetric_soft_quantizer
18
+ import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.uniform_soft_quantizer
@@ -15,8 +15,8 @@
15
15
  from abc import abstractmethod
16
16
  from typing import Union, Dict, List
17
17
 
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import FOUND_TF
20
20
  from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
21
21
 
22
22
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \