mct-nightly 1.8.0.22042023.post414__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +1 -1
  2. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +237 -230
  3. model_compression_toolkit/__init__.py +8 -31
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/fusion/layer_fusing.py +2 -2
  13. model_compression_toolkit/core/common/graph/base_graph.py +1 -1
  14. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  15. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  17. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  18. model_compression_toolkit/core/common/memory_computation.py +1 -1
  19. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +2 -3
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  22. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  27. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  28. model_compression_toolkit/core/common/model_collector.py +2 -2
  29. model_compression_toolkit/core/common/model_validation.py +1 -1
  30. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  31. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  32. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  33. model_compression_toolkit/core/common/quantization/node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  35. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  36. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  37. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +2 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +2 -1
  47. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  49. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  50. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  52. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -2
  53. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  54. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -3
  55. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -2
  56. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  57. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +2 -2
  58. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  59. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +4 -4
  60. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  61. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  62. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  63. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  64. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  65. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  66. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  67. model_compression_toolkit/core/keras/back2framework/model_gradients.py +2 -2
  68. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  69. model_compression_toolkit/core/keras/constants.py +0 -7
  70. model_compression_toolkit/core/keras/default_framework_info.py +2 -2
  71. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  72. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  73. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  74. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  79. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  80. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  81. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  82. model_compression_toolkit/core/keras/kpi_data_facade.py +7 -7
  83. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  84. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  85. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  86. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  87. model_compression_toolkit/core/keras/reader/common.py +1 -1
  88. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  89. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  90. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  91. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  92. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +2 -2
  93. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  94. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  95. model_compression_toolkit/core/pytorch/constants.py +0 -6
  96. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  98. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  99. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  100. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  101. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  102. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  103. model_compression_toolkit/core/pytorch/kpi_data_facade.py +6 -6
  104. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  105. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -9
  106. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  107. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  108. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  109. model_compression_toolkit/core/pytorch/reader/graph_builders.py +3 -2
  110. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  111. model_compression_toolkit/core/runner.py +6 -6
  112. model_compression_toolkit/exporter/__init__.py +6 -3
  113. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  114. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  115. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  116. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  117. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  118. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  119. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  120. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  121. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  123. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  124. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -2
  125. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  126. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  127. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +3 -2
  128. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  129. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  130. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  131. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  132. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  133. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  134. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  135. model_compression_toolkit/gptq/keras/quantization_facade.py +26 -19
  136. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  137. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +1 -1
  138. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  139. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +2 -2
  140. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  141. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  142. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  143. model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -11
  144. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  145. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +1 -3
  146. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  147. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +2 -2
  148. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  149. model_compression_toolkit/gptq/runner.py +3 -2
  150. model_compression_toolkit/{exporter/model_exporter/tflite → legacy}/__init__.py +1 -1
  151. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +8 -9
  152. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +8 -9
  153. model_compression_toolkit/ptq/__init__.py +3 -0
  154. model_compression_toolkit/ptq/keras/quantization_facade.py +10 -11
  155. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -7
  156. model_compression_toolkit/qat/__init__.py +4 -0
  157. model_compression_toolkit/qat/common/__init__.py +1 -2
  158. model_compression_toolkit/qat/common/qat_config.py +5 -1
  159. model_compression_toolkit/qat/keras/quantization_facade.py +33 -27
  160. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  161. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  162. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +12 -10
  163. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +8 -8
  164. model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
  165. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  166. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +3 -2
  167. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +6 -4
  168. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  169. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -2
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +1 -1
  178. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  179. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +1 -1
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +1 -1
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +1 -1
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +1 -1
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +2 -2
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +1 -2
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +1 -1
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +1 -1
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +1 -1
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +1 -1
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +1 -1
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +1 -1
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  201. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  202. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +3 -5
  203. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  205. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  206. model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
  207. model_compression_toolkit/target_platform_capabilities/target_platform/operators.py +1 -1
  208. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  209. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +11 -2
  210. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  211. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py +32 -34
  212. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -2
  213. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -24
  214. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +1 -1
  215. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/target_platform_capabilities.py +3 -1
  216. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v1/tp_model.py +7 -1
  217. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v2/tp_model.py +7 -1
  218. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v3/tp_model.py +7 -1
  219. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v3_lut/tp_model.py +7 -2
  220. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v4/tp_model.py +7 -1
  221. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v4_lut/tp_model.py +7 -2
  222. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v5/tp_model.py +7 -1
  223. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +1 -3
  224. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +1 -1
  225. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +2 -1
  226. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  227. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +1 -1
  228. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +2 -1
  229. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  230. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +1 -1
  231. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +2 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  233. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  234. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  235. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  236. {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  237. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  238. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
@@ -20,7 +20,7 @@ import keras.models
20
20
  import tensorflow as tf
21
21
 
22
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
23
- from model_compression_toolkit.core.common import Logger
23
+ from model_compression_toolkit.logger import Logger
24
24
  from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
25
25
 
26
26
 
@@ -23,7 +23,7 @@ from keras.layers import Dense, Conv2D, Reshape
23
23
  from keras.models import clone_model
24
24
 
25
25
  from model_compression_toolkit import quantizers_infrastructure as qi
26
- from model_compression_toolkit.core.common import Logger
26
+ from model_compression_toolkit.logger import Logger
27
27
  from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
28
28
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
29
29
  constants as keras_inferable_constants
@@ -12,53 +12,91 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from enum import Enum
16
15
  from typing import Callable, Dict
17
16
 
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
20
-
21
-
22
- class KerasExportMode(Enum):
23
- FAKELY_QUANT = 0
24
-
17
+ from model_compression_toolkit.constants import FOUND_TF
18
+ from model_compression_toolkit.exporter.model_exporter.keras.export_serialization_format import \
19
+ KerasExportSerializationFormat
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
23
+ QuantizationFormat
25
24
 
26
25
  if FOUND_TF:
27
26
  import keras
28
27
  from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
29
- from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
28
+ from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import \
29
+ FakelyQuantKerasExporter
30
+ from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_tflite_exporter import \
31
+ FakelyQuantTFLiteExporter
32
+ from model_compression_toolkit.exporter.model_exporter.keras.int8_tflite_exporter import INT8TFLiteExporter
33
+
34
+ supported_serialization_quantization_export_dict = {
35
+ KerasExportSerializationFormat.KERAS_H5: [QuantizationFormat.FAKELY_QUANT],
36
+ KerasExportSerializationFormat.TFLITE: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.INT8]
37
+ }
30
38
 
31
39
  def keras_export_model(model: keras.models.Model,
32
40
  save_model_path: str,
41
+ target_platform_capabilities: TargetPlatformCapabilities,
33
42
  is_layer_exportable_fn: Callable = is_keras_layer_exportable,
34
- mode: KerasExportMode = KerasExportMode.FAKELY_QUANT) -> Dict[str, type]:
43
+ serialization_format: KerasExportSerializationFormat =
44
+ KerasExportSerializationFormat.KERAS_H5) -> \
45
+ Dict[str, type]:
35
46
  """
36
- Export a Keras quantized model to h5 model.
47
+ Export a Keras quantized model to a h5 or tflite model.
37
48
  The model will be saved to the path in save_model_path.
38
- Mode can be used for different exported files. Currently, keras_export_model
39
- supports KerasExportMode.FAKELY_QUANT (where weights and activations are
40
- float fakely-quantized values).
49
+ keras_export_model supports the combination of QuantizationFormat.FAKELY_QUANT (where weights
50
+ and activations are float fakely-quantized values) and KerasExportSerializationFormat.KERAS_H5 (where the model
51
+ will be saved to h5 model) or the combination of KerasExportSerializationFormat.TFLITE (where the model will be
52
+ saved to tflite model) with QuantizationFormat.FAKELY_QUANT or QuantizationFormat.INT8 (where weights and
53
+ activations are represented using 8bits integers).
41
54
 
42
55
  Args:
43
56
  model: Model to export.
44
- is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
45
- mode: Mode to export the model according to.
46
57
  save_model_path: Path to save the model.
58
+ target_platform_capabilities: TargetPlatformCapabilities object that describes the desired inference
59
+ target platform (includes quantization format).
60
+ is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
61
+ serialization_format: Format to export the model according to (by default
62
+ KerasExportSerializationFormat.KERAS_H5).
47
63
 
48
64
  Returns:
49
65
  Custom objects dictionary needed to load the model.
50
66
 
51
67
  """
52
68
 
53
- if mode == KerasExportMode.FAKELY_QUANT:
54
- exporter = FakelyQuantKerasExporter(model,
55
- is_layer_exportable_fn,
56
- save_model_path)
69
+ if serialization_format == KerasExportSerializationFormat.KERAS_H5:
70
+ if target_platform_capabilities.tp_model.quantization_format == QuantizationFormat.FAKELY_QUANT:
71
+ exporter = FakelyQuantKerasExporter(model,
72
+ is_layer_exportable_fn,
73
+ save_model_path)
74
+ else:
75
+ Logger.critical(
76
+ f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
77
+ f'serialization {serialization_format} was used to export Keras model. Please see API for '
78
+ f'supported formats.') # pragma: no cover
79
+
80
+ elif serialization_format == KerasExportSerializationFormat.TFLITE:
81
+ if target_platform_capabilities.tp_model.quantization_format == QuantizationFormat.FAKELY_QUANT:
82
+ exporter = FakelyQuantTFLiteExporter(model,
83
+ is_layer_exportable_fn,
84
+ save_model_path)
85
+
86
+ elif target_platform_capabilities.tp_model.quantization_format == QuantizationFormat.INT8:
87
+ exporter = INT8TFLiteExporter(model,
88
+ is_layer_exportable_fn,
89
+ save_model_path)
90
+ else:
91
+ Logger.critical(
92
+ f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
93
+ f'serialization {serialization_format} was used to export Keras model. Please see API for '
94
+ f'supported formats.') # pragma: no cover
57
95
 
58
96
  else:
59
97
  Logger.critical(
60
- f'Unsupported mode was used {mode.name} to '
61
- f'export Keras model. Please see API for supported modes.') # pragma: no cover
98
+ f'Unsupported serialization {serialization_format} was used to export Keras model. Please see API '
99
+ f'for supported formats.') # pragma: no cover
62
100
 
63
101
  exporter.export()
64
102
 
@@ -0,0 +1,20 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from enum import Enum
16
+
17
+
18
+ class PytorchExportSerializationFormat(Enum):
19
+ TORCHSCRIPT = 0
20
+ ONNX = 1
@@ -16,17 +16,21 @@ from typing import Callable
16
16
 
17
17
  import torch.nn
18
18
 
19
- from model_compression_toolkit.core.common import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
21
21
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
22
22
  from packaging import version
23
23
 
24
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
25
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER
26
+
24
27
  # ONNX opset version 16 is supported from PyTorch 1.12
25
28
  if version.parse(torch.__version__) < version.parse("1.12"):
26
29
  OPSET_VERSION = 15
27
30
  else:
28
31
  OPSET_VERSION = 16
29
32
 
33
+
30
34
  class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
31
35
  """
32
36
  Exporter for fakely-quant PyTorch models.
@@ -70,6 +74,16 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
70
74
 
71
75
  Logger.info(f"Exporting PyTorch fake quant onnx model: {self.save_model_path}")
72
76
 
77
+ # Replace float weight with wrapped quantized weights
78
+ for layer in self.model.modules():
79
+ if isinstance(layer, PytorchQuantizationWrapper):
80
+ for name in layer.weights_quantizers.keys():
81
+ quantized_weight = torch.nn.Parameter(layer.get_quantized_weights()[name]).detach()
82
+ linear_layer = getattr(layer, LAYER)
83
+ delattr(linear_layer, name)
84
+ setattr(linear_layer, name, torch.nn.Parameter(quantized_weight))
85
+ layer.weights_quantizers = {}
86
+
73
87
  torch.onnx.export(self.model,
74
88
  model_input,
75
89
  self.save_model_path,
@@ -16,7 +16,7 @@ from typing import Callable
16
16
 
17
17
  import torch.nn
18
18
 
19
- from model_compression_toolkit.core.common import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
21
21
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
22
22
 
@@ -12,63 +12,86 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from enum import Enum
16
15
  from typing import Callable
17
16
 
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
20
-
21
-
22
- class PyTorchExportMode(Enum):
23
- FAKELY_QUANT_TORCHSCRIPT = 0
24
- FAKELY_QUANT_ONNX = 1
25
-
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
19
+ PytorchExportSerializationFormat
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
23
+ QuantizationFormat
26
24
 
27
25
  if FOUND_TORCH:
28
26
  import torch.nn
29
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
30
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
27
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
28
+ FakelyQuantONNXPyTorchExporter
29
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import \
30
+ FakelyQuantTorchScriptPyTorchExporter
31
31
  from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
32
32
 
33
+ supported_serialization_quantization_export_dict = {
34
+ PytorchExportSerializationFormat.TORCHSCRIPT: [QuantizationFormat.FAKELY_QUANT],
35
+ PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT]
36
+ }
37
+
33
38
  def pytorch_export_model(model: torch.nn.Module,
34
39
  save_model_path: str,
35
40
  repr_dataset: Callable,
41
+ target_platform_capabilities: TargetPlatformCapabilities,
36
42
  is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
37
- mode: PyTorchExportMode = PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT) -> None:
43
+ serialization_format: PytorchExportSerializationFormat =
44
+ PytorchExportSerializationFormat.TORCHSCRIPT) -> None:
38
45
  """
39
46
  Export a PyTorch quantized model to a torchscript or onnx model.
40
47
  The model will be saved to the path in save_model_path.
41
- Mode can be used for different exported files. Currently, pytorch_export_model
42
- supports PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT (where the exported model
43
- is in a TorchScript format and its weights and activations are float fakely-quantized values),
44
- and PyTorchExportMode.FakelyQuantONNX (where the exported model
45
- is in an ONNX format and its weights and activations are float fakely-quantized values)
48
+ Currently, pytorch_export_model supports only QuantizationFormat.FAKELY_QUANT (where weights
49
+ and activations are float fakely-quantized values) and PytorchExportSerializationFormat.TORCHSCRIPT
50
+ (where the model will be saved to TorchScript model) or PytorchExportSerializationFormat.ONNX
51
+ (where the model will be saved to ONNX model).
46
52
 
47
53
  Args:
48
54
  model: Model to export.
49
- is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
50
- mode: Mode to export the model according to.
51
55
  save_model_path: Path to save the model.
52
56
  repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
57
+ target_platform_capabilities: TargetPlatformCapabilities object that describes the desired inference
58
+ target platform (includes quantization format).
59
+ is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
60
+ serialization_format: Format to export the model according to (by default
61
+ PytorchExportSerializationFormat.TORCHSCRIPT).
53
62
 
54
63
  """
55
64
 
56
- if mode == PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT:
57
- exporter = FakelyQuantTorchScriptPyTorchExporter(model,
58
- is_layer_exportable_fn,
59
- save_model_path,
60
- repr_dataset)
65
+ if serialization_format == PytorchExportSerializationFormat.TORCHSCRIPT:
66
+ if target_platform_capabilities.tp_model.quantization_format in \
67
+ supported_serialization_quantization_export_dict[serialization_format]:
68
+ exporter = FakelyQuantTorchScriptPyTorchExporter(model,
69
+ is_layer_exportable_fn,
70
+ save_model_path,
71
+ repr_dataset)
72
+ else:
73
+ Logger.critical(
74
+ f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
75
+ f'serialization {serialization_format} was used to export Pytorch model. Please see API for '
76
+ f'supported formats.') # pragma: no cover
61
77
 
62
- elif mode == PyTorchExportMode.FAKELY_QUANT_ONNX:
63
- exporter = FakelyQuantONNXPyTorchExporter(model,
64
- is_layer_exportable_fn,
65
- save_model_path,
66
- repr_dataset)
78
+ elif serialization_format == PytorchExportSerializationFormat.ONNX:
79
+ if target_platform_capabilities.tp_model.quantization_format in \
80
+ supported_serialization_quantization_export_dict[serialization_format]:
81
+ exporter = FakelyQuantONNXPyTorchExporter(model,
82
+ is_layer_exportable_fn,
83
+ save_model_path,
84
+ repr_dataset)
85
+ else:
86
+ Logger.critical(
87
+ f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
88
+ f'serialization {serialization_format} was used to export Pytorch model. Please see API for '
89
+ f'supported formats.') # pragma: no cover
67
90
 
68
91
  else:
69
92
  Logger.critical(
70
- f'Unsupported mode was used {mode.name} to export PyTorch model. '
71
- f'Please see API for supported modes.') # pragma: no cover
93
+ f'Unsupported serialization {serialization_format} was used to export Pytorch model. Please see API '
94
+ f'for supported formats.') # pragma: no cover
72
95
 
73
96
  exporter.export()
74
97
 
@@ -17,9 +17,10 @@ from typing import Tuple
17
17
 
18
18
  from model_compression_toolkit import quantizers_infrastructure as qi
19
19
  from model_compression_toolkit.core import common
20
- from model_compression_toolkit.core.common import Graph, Logger
21
- from model_compression_toolkit.core.common.constants import FOUND_TF
20
+ from model_compression_toolkit.core.common import Graph
21
+ from model_compression_toolkit.constants import FOUND_TF
22
22
  from model_compression_toolkit.core.common.user_info import UserInformation
23
+ from model_compression_toolkit.logger import Logger
23
24
 
24
25
  if FOUND_TF:
25
26
  import tensorflow as tf
@@ -34,6 +35,7 @@ if FOUND_TF:
34
35
  Args:
35
36
  n: A node of mct graph.
36
37
  layer: A keras layer
38
+ include_activation_quantizers: Whether to use the wrapper for the activation quantizer or not
37
39
 
38
40
  Returns: Wrapped layer with weights quantizers and activation quantizers
39
41
 
@@ -55,7 +57,7 @@ if FOUND_TF:
55
57
  Exportable Keras model and user information.
56
58
  """
57
59
  exportable_model, user_info = KerasModelBuilder(graph=graph,
58
- wrapper=_get_wrapper).build_model()
60
+ wrapper=_get_wrapper).build_model()
59
61
  exportable_model.trainable = False
60
62
  return exportable_model, user_info
61
63
  else:
@@ -14,8 +14,10 @@
14
14
  # ==============================================================================
15
15
  from typing import Dict, Any
16
16
 
17
- from model_compression_toolkit.core.common import BaseNode, Logger
18
- from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
17
+ from model_compression_toolkit.core.common import BaseNode
18
+ from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
19
+
20
+ from model_compression_toolkit.logger import Logger
19
21
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
21
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
@@ -15,8 +15,8 @@
15
15
  from typing import Any
16
16
 
17
17
 
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import FOUND_TF
20
20
 
21
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
22
22
 
@@ -16,8 +16,9 @@
16
16
 
17
17
  from model_compression_toolkit import quantizers_infrastructure as qi
18
18
  from model_compression_toolkit.core import common
19
- from model_compression_toolkit.core.common import Graph, Logger
20
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
+ from model_compression_toolkit.core.common import Graph
20
+ from model_compression_toolkit.constants import FOUND_TORCH
21
+ from model_compression_toolkit.logger import Logger
21
22
 
22
23
  if FOUND_TORCH:
23
24
  import torch
@@ -15,9 +15,10 @@
15
15
 
16
16
  from typing import Dict, Any
17
17
 
18
- from model_compression_toolkit.core.common import BaseNode, Logger
19
- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
18
+ from model_compression_toolkit.core.common import BaseNode
19
+ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
20
20
  SCALE_PER_CHANNEL, CLUSTER_CENTERS
21
+ from model_compression_toolkit.logger import Logger
21
22
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
23
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
23
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Any
16
16
 
17
- from model_compression_toolkit.core.common import Logger
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
19
 
20
20
  if FOUND_TORCH:
21
21
  from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
@@ -0,0 +1,32 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from abc import abstractmethod
17
+
18
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
19
+
20
+
21
+ class GPTQFrameworkImplemantation(FrameworkImplementation):
22
+ """
23
+ Class to implement framework related methods that are used in GPTQ
24
+ """
25
+
26
+ @abstractmethod
27
+ def get_gptq_trainer_obj(self):
28
+ """
29
+ Returns: GPTQTrainer object
30
+ """
31
+ raise NotImplemented(f'{self.__class__.__name__} have to implement the '
32
+ f'framework\'s get_gptq_trainer method.') # pragma: no cover
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Tuple, List
16
16
 
17
- from model_compression_toolkit import FrameworkInfo
18
- from model_compression_toolkit.core.common import Logger
17
+ from model_compression_toolkit.core import FrameworkInfo
18
+ from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.core.common.graph.base_graph import Graph
20
20
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
21
21
 
@@ -17,12 +17,13 @@ from abc import ABC, abstractmethod
17
17
  import numpy as np
18
18
  from typing import Callable, List, Any
19
19
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
20
- from model_compression_toolkit.core.common import Graph, Logger, BaseNode
20
+ from model_compression_toolkit.core.common import Graph, BaseNode
21
21
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
- from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
22
  from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
23
+ from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
24
24
  from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
25
25
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
26
+ from model_compression_toolkit.logger import Logger
26
27
 
27
28
 
28
29
  class GPTQTrainer(ABC):
@@ -34,7 +35,7 @@ class GPTQTrainer(ABC):
34
35
  graph_float: Graph,
35
36
  graph_quant: Graph,
36
37
  gptq_config: GradientPTQConfig,
37
- fw_impl: FrameworkImplementation,
38
+ fw_impl: GPTQFrameworkImplemantation,
38
39
  fw_info: FrameworkInfo):
39
40
  """
40
41
  Build two models from a graph: A teacher network (float model) and a student network (quantized model).
@@ -259,7 +260,7 @@ def gptq_training(graph_float: Graph,
259
260
  graph_quant: Graph,
260
261
  gptq_config: GradientPTQConfig,
261
262
  representative_data_gen: Callable,
262
- fw_impl: FrameworkImplementation,
263
+ fw_impl: GPTQFrameworkImplemantation,
263
264
  fw_info: FrameworkInfo) -> Graph:
264
265
  """
265
266
  GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).
@@ -0,0 +1,29 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Type
17
+
18
+ from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
19
+ from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
20
+ from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
21
+
22
+
23
+ class GPTQKerasImplemantation(GPTQFrameworkImplemantation, KerasImplementation):
24
+
25
+ def get_gptq_trainer_obj(self) -> Type[KerasGPTQTrainer]:
26
+ """
27
+ Returns: Keras object of GPTQTrainer
28
+ """
29
+ return KerasGPTQTrainer