mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -14,61 +14,69 @@
14
14
  # ==============================================================================
15
15
  from typing import Any
16
16
 
17
- from keras.engine.input_layer import InputLayer
18
-
19
- from model_compression_toolkit.core.common import Logger
20
- from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
21
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
22
17
 
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import FOUND_TF
23
20
 
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
24
22
 
25
- def is_keras_layer_exportable(layer: Any) -> bool:
26
- """
27
- Check whether a Keras layer is a valid exportable layer or not.
28
23
 
29
- Args:
30
- layer: Keras layer to check if considered to be valid for exporting.
24
+ if FOUND_TF:
25
+ from keras.engine.input_layer import InputLayer
26
+ from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
31
27
 
32
- Returns:
28
+ def is_keras_layer_exportable(layer: Any) -> bool:
29
+ """
33
30
  Check whether a Keras layer is a valid exportable layer or not.
34
- """
35
- # Keras Input layers are not wrapped
36
- if isinstance(layer, InputLayer):
37
- return True
38
31
 
39
- valid_layer = isinstance(layer, KerasQuantizationWrapper)
40
- if not valid_layer:
41
- Logger.error(
42
- f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
43
- f'{type(layer)}') # pragma: no cover
32
+ Args:
33
+ layer: Keras layer to check if considered to be valid for exporting.
44
34
 
45
- valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
46
- if not valid_weights_quantizers:
47
- Logger.error(
48
- f'KerasQuantizationWrapper must have a weights_quantizers but has a '
49
- f'{type(layer.weights_quantizers)} object') # pragma: no cover
35
+ Returns:
36
+ Check whether a Keras layer is a valid exportable layer or not.
37
+ """
38
+ # Keras Input layers are not wrapped
39
+ if isinstance(layer, InputLayer):
40
+ return True
50
41
 
51
- for _, weights_quantizer in layer.weights_quantizers.items():
52
- if not isinstance(weights_quantizer, BaseInferableQuantizer):
42
+ valid_layer = isinstance(layer, KerasQuantizationWrapper)
43
+ if not valid_layer:
53
44
  Logger.error(
54
- f'weights_quantizer must be a BaseInferableQuantizer object but has a '
55
- f'{type(weights_quantizer)} object') # pragma: no cover
45
+ f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
46
+ f'{type(layer)}') # pragma: no cover
56
47
 
57
- valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
58
- if not valid_activation_quantizers:
59
- Logger.error(
60
- f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
61
- f'{type(layer.activation_quantizers)} object') # pragma: no cover
48
+ valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
49
+ if not valid_weights_quantizers:
50
+ Logger.error(
51
+ f'KerasQuantizationWrapper must have a weights_quantizers but has a '
52
+ f'{type(layer.weights_quantizers)} object') # pragma: no cover
53
+
54
+ for _, weights_quantizer in layer.weights_quantizers.items():
55
+ if not isinstance(weights_quantizer, BaseInferableQuantizer):
56
+ Logger.error(
57
+ f'weights_quantizer must be a BaseInferableQuantizer object but has a '
58
+ f'{type(weights_quantizer)} object') # pragma: no cover
62
59
 
63
- for activation_quantizers in layer.activation_quantizers:
64
- if not isinstance(activation_quantizers, BaseInferableQuantizer):
60
+ valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
61
+ if not valid_activation_quantizers:
65
62
  Logger.error(
66
- f'activation_quantizers must be a BaseInferableQuantizer object but has a '
67
- f'{type(activation_quantizers)} object') # pragma: no cover
63
+ f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
64
+ f'{type(layer.activation_quantizers)} object') # pragma: no cover
68
65
 
69
- quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
70
- is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
71
- if not is_valid_quantizers:
72
- Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer') # pragma: no cover
66
+ for activation_quantizers in layer.activation_quantizers:
67
+ if not isinstance(activation_quantizers, BaseInferableQuantizer):
68
+ Logger.error(
69
+ f'activation_quantizers must be a BaseInferableQuantizer object but has a '
70
+ f'{type(activation_quantizers)} object') # pragma: no cover
73
71
 
74
- return True
72
+ quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
73
+ is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
74
+ if not is_valid_quantizers:
75
+ Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer') # pragma: no cover
76
+
77
+ return True
78
+ else:
79
+ def is_keras_layer_exportable(*args, **kwargs): # pragma: no cover
80
+ Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
81
+ 'when using is_keras_layer_exportable. '
82
+ 'Could not find Tensorflow package.')
@@ -13,42 +13,50 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import torch
17
16
 
18
17
  from model_compression_toolkit import quantizers_infrastructure as qi
19
18
  from model_compression_toolkit.core import common
20
19
  from model_compression_toolkit.core.common import Graph
21
- from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
22
- from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
23
- get_quantization_quantizers
24
-
25
-
26
- def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
27
- """
28
- A function which takes a computational graph node and a pytorch module and
29
- perform the quantization wrapping
30
-
31
- Args:
32
- node: A node of mct graph.
33
- module: A Pytorch module
34
-
35
- Returns: Wrapped layer
36
-
37
- """
38
- weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
39
- wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
40
- return wrapped_layer
41
-
42
-
43
- def get_exportable_pytorch_model(graph: Graph):
44
- """
45
- Convert graph to fully quantized PyTorch model.
46
-
47
- Args:
48
- graph: Graph to convert to a PyTorch model.
49
-
50
- Returns:
51
- Fully quantized PyTorch model.
52
- """
53
- return PyTorchModelBuilder(graph=graph,
54
- wrapper=fully_quantized_wrapper).build_model()
20
+ from model_compression_toolkit.constants import FOUND_TORCH
21
+ from model_compression_toolkit.logger import Logger
22
+
23
+ if FOUND_TORCH:
24
+ import torch
25
+ from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
26
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
27
+ get_quantization_quantizers
28
+
29
+ def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
30
+ """
31
+ A function which takes a computational graph node and a pytorch module and
32
+ perform the quantization wrapping
33
+
34
+ Args:
35
+ node: A node of mct graph.
36
+ module: A Pytorch module
37
+
38
+ Returns: Wrapped layer
39
+
40
+ """
41
+ weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
42
+ wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
43
+ return wrapped_layer
44
+
45
+
46
+ def get_exportable_pytorch_model(graph: Graph):
47
+ """
48
+ Convert graph to fully quantized PyTorch model.
49
+
50
+ Args:
51
+ graph: Graph to convert to a PyTorch model.
52
+
53
+ Returns:
54
+ Fully quantized PyTorch model.
55
+ """
56
+ return PyTorchModelBuilder(graph=graph,
57
+ wrapper=fully_quantized_wrapper).build_model()
58
+ else:
59
+ def get_exportable_pytorch_model(*args, **kwargs): # pragma: no cover
60
+ Logger.error('Installing torch is mandatory '
61
+ 'when using get_exportable_pytorch_model. '
62
+ 'Could not find PyTorch package.')
@@ -15,9 +15,11 @@
15
15
 
16
16
  from typing import Dict, Any
17
17
 
18
- from model_compression_toolkit.core.common import BaseNode, Logger
19
- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.core.common import BaseNode
19
+ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
20
+ SCALE_PER_CHANNEL, CLUSTER_CENTERS
21
+ from model_compression_toolkit.logger import Logger
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
23
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
22
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
23
25
  get_inferable_quantizer_class
@@ -45,6 +47,15 @@ def get_weights_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
45
47
  qi_inferable_quantizers_constants.MIN_RANGE: node_w_qc.weights_quantization_params[RANGE_MIN].flatten(),
46
48
  qi_inferable_quantizers_constants.MAX_RANGE: node_w_qc.weights_quantization_params[RANGE_MAX].flatten(),
47
49
  qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
50
+
51
+ elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
52
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
53
+ qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
54
+ qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
55
+ qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
56
+ qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
57
+ # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
58
+
48
59
  else:
49
60
  Logger.critical(f'Not supported quantization method for weights inferable quantizers.') # pragma: no cover
50
61
 
@@ -65,6 +76,15 @@ def get_activation_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
65
76
  return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
66
77
  qi_inferable_quantizers_constants.MIN_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MIN]]),
67
78
  qi_inferable_quantizers_constants.MAX_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MAX]])}
79
+
80
+ elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
81
+ return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
82
+ qi_inferable_quantizers_constants.CLUSTER_CENTERS: np.asarray(
83
+ [node_qc.activation_quantization_params[CLUSTER_CENTERS]]),
84
+ qi_inferable_quantizers_constants.THRESHOLD: np.asarray(
85
+ [node_qc.activation_quantization_params[THRESHOLD]]),
86
+ qi_inferable_quantizers_constants.SIGNED: node_qc.activation_quantization_params.get(SIGNED)}
87
+ # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
68
88
  else:
69
89
  Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover
70
90
 
@@ -111,10 +131,10 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQu
111
131
  node_act_qc = node.final_activation_quantization_cfg
112
132
  activation_quantization_method = node_act_qc.activation_quantization_method
113
133
 
114
- quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
115
- activation_quantization_method,
116
- BasePyTorchInferableQuantizer)
134
+ quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
135
+ activation_quantization_method,
136
+ BasePyTorchInferableQuantizer)
117
137
  kwargs = get_activation_inferable_quantizer_kwargs(node)
118
138
 
119
- return quantier_for_node(**kwargs)
139
+ return quantizer_for_node(**kwargs)
120
140
 
@@ -14,24 +14,31 @@
14
14
  # ==============================================================================
15
15
  from typing import Any
16
16
 
17
- from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
18
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
19
- BasePyTorchInferableQuantizer
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import FOUND_TORCH
20
19
 
20
+ if FOUND_TORCH:
21
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
23
+ BasePyTorchInferableQuantizer
24
+ def is_pytorch_layer_exportable(layer: Any) -> bool:
25
+ """
26
+ Check whether a torch Module is a valid exportable module or not.
21
27
 
22
- def is_pytorch_layer_exportable(layer: Any) -> bool:
23
- """
24
- Check whether a torch Module is a valid exportable module or not.
28
+ Args:
29
+ layer: PyTorch module to check if considered to be valid for exporting.
25
30
 
26
- Args:
27
- layer: PyTorch module to check if considered to be valid for exporting.
28
-
29
- Returns:
30
- Check whether a PyTorch layer is a valid exportable layer or not.
31
- """
32
- if isinstance(layer, PytorchQuantizationWrapper):
33
- quantizers = list(layer.weights_quantizers.values())
34
- quantizers.extend(layer.activation_quantizers)
35
- if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
36
- return True
37
- return False
31
+ Returns:
32
+ Check whether a PyTorch layer is a valid exportable layer or not.
33
+ """
34
+ if isinstance(layer, PytorchQuantizationWrapper):
35
+ quantizers = list(layer.weights_quantizers.values())
36
+ quantizers.extend(layer.activation_quantizers)
37
+ if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
38
+ return True
39
+ return False
40
+ else:
41
+ def is_pytorch_layer_exportable(*args, **kwargs): # pragma: no cover
42
+ Logger.error('Installing torch is mandatory '
43
+ 'when using is_pytorch_layer_exportable. '
44
+ 'Could not find PyTorch package.')
@@ -12,3 +12,9 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
16
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2
17
+ from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization_experimental
18
+ from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
19
+ from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization_experimental
20
+ from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
@@ -16,9 +16,7 @@ from enum import Enum
16
16
  from typing import Callable, Any, Dict
17
17
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
18
18
  from model_compression_toolkit.core import common
19
- from model_compression_toolkit.gptq.common.gptq_constants import N_BATCHES_STR, QUANT_PARAM_LEARNING_STR, N_EPOCHS_STR, \
20
- MAX_LSB_STR
21
- from model_compression_toolkit.gptq.common.gptq_quantizer_config import GPTQQuantizerConfig, SoftQuantizerConfig
19
+ from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR, REG_DEFAULT
22
20
 
23
21
 
24
22
  class RoundingType(Enum):
@@ -31,30 +29,53 @@ class RoundingType(Enum):
31
29
  SoftQuantizer = 1
32
30
 
33
31
 
32
+ class GPTQHessianWeightsConfig:
33
+ """
34
+ Configuration to use for computing the Hessian-based weights for GPTQ loss metric.
35
+ """
36
+
37
+ def __init__(self,
38
+ hessians_num_samples: int = 16,
39
+ norm_weights: bool = True,
40
+ log_norm: bool = True,
41
+ scale_log_norm: bool = False,
42
+ hessians_n_iter: int = 50):
43
+
44
+ """
45
+ Initialize a GPTQHessianWeightsConfig.
46
+ Args:
47
+ hessians_num_samples (int): Number of samples to use for computing the Hessian-based weights.
48
+ norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
49
+ log_norm (bool): Whether to use log normalization to the GPTQ Hessian-based weights.
50
+ scale_log_norm (bool): Whether to scale the final vector of the Hessian weights.
51
+ hessians_n_iter (int): Number of random iterations to run Hessian approximation for GPTQ weights.
52
+ """
53
+
54
+ self.hessians_num_samples = hessians_num_samples
55
+ self.norm_weights = norm_weights
56
+ self.log_norm = log_norm
57
+ self.scale_log_norm = scale_log_norm
58
+ self.hessians_n_iter = hessians_n_iter
59
+
60
+
34
61
  class GradientPTQConfig:
35
62
  """
36
63
  Configuration to use for quantization with GradientPTQ (experimental).
37
64
  """
38
65
 
39
- def __init__(self,
40
- n_iter: int,
66
+ def __init__(self, n_iter: int,
41
67
  optimizer: Any,
42
68
  optimizer_rest: Any = None,
43
69
  loss: Callable = None,
44
70
  log_function: Callable = None,
45
71
  train_bias: bool = True,
46
- quantization_parameters_learning: bool = False,
47
72
  rounding_type: RoundingType = RoundingType.SoftQuantizer,
48
- lsb_change_per_bit_width: dict = DefaultDict({}, lambda: 1),
49
- eps: float = 1e-6,
50
- use_jac_based_weights: bool = True,
51
- num_samples_for_loss: int = 16,
52
- norm_weights: bool = False,
73
+ use_hessian_based_weights: bool = True,
53
74
  optimizer_quantization_parameter: Any = None,
54
75
  optimizer_bias: Any = None,
55
- log_norm: bool = True,
56
- weights_n_iter: int = 50,
57
- quantizer_config: GPTQQuantizerConfig = SoftQuantizerConfig()):
76
+ regularization_factor: float = REG_DEFAULT,
77
+ hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
78
+ gptq_quantizer_params_override: Dict[str, Any] = None):
58
79
  """
59
80
  Initialize a GradientPTQConfig.
60
81
 
@@ -67,18 +88,13 @@ class GradientPTQConfig:
67
88
  accordingly. see example in multiple_tensors_mse_loss
68
89
  log_function (Callable): Function to log information about the GPTQ process.
69
90
  train_bias (bool): Whether to update the bias during the training or not.
70
- quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
71
91
  rounding_type (RoundingType): An enum that defines the rounding type.
72
- lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
73
- eps (float): A floating point value for numeric stability.
74
- use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
75
- num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
76
- norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
92
+ use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
77
93
  optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
78
94
  optimizer_bias (Any): Optimizer to override the rest optimizer for bias.
79
- log_norm (bool): Whether to use log normalization to the GPTQ Jacobian-based weights.
80
- weights_n_iter (int): Number of random iterations to run Jacobian approximation for GPTQ weights.
81
- quantizer_config (GPTQQuantizerConfig): A class that contains the quantizer specific config.
95
+ regularization_factor (float): A floating point number that defines the regularization factor.
96
+ hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
97
+ gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
82
98
 
83
99
  """
84
100
  self.n_iter = n_iter
@@ -88,68 +104,34 @@ class GradientPTQConfig:
88
104
  self.log_function = log_function
89
105
  self.train_bias = train_bias
90
106
 
91
- if quantization_parameters_learning and rounding_type == RoundingType.STE:
92
- common.Logger.error("Quantization parameters learning is not supported with STE rounding.")
93
-
94
- self.quantization_parameters_learning = quantization_parameters_learning
95
107
  self.rounding_type = rounding_type
96
- self.lsb_change_per_bit_width = lsb_change_per_bit_width
97
- self.eps = eps
98
- self.use_jac_based_weights = use_jac_based_weights
99
- self.num_samples_for_loss = num_samples_for_loss
100
- self.norm_weights = norm_weights
108
+ self.use_hessian_based_weights = use_hessian_based_weights
101
109
  self.optimizer_quantization_parameter = optimizer_quantization_parameter
102
110
  self.optimizer_bias = optimizer_bias
103
- self.log_norm = log_norm
104
- self.weights_n_iter = weights_n_iter
111
+ self.regularization_factor = regularization_factor
112
+ self.hessian_weights_config = hessian_weights_config
105
113
 
106
- if self._verify_quantizer_config(quantizer_config, rounding_type):
107
- self.quantizer_config = quantizer_config
108
- else:
109
- common.Logger.error(f"Quantizer config of type {type(quantizer_config)} "
110
- f"is not suitable for rounding type {rounding_type}")
111
-
112
- def _verify_quantizer_config(self, quantizer_config, rounding_type) -> bool:
113
- """
114
- Verifies that the given quantizer config matches the given rounding type.
115
-
116
- Args:
117
- quantizer_config: A quantizer config.
118
- rounding_type: A RoundingType.
119
-
120
- Returns: True if the quantizer config matches the rounding type, False otherwise.
121
-
122
- """
123
- if rounding_type == RoundingType.SoftQuantizer:
124
- return type(quantizer_config) == SoftQuantizerConfig
125
-
126
- # Here, we compare type() and not isinstance to exclude instance equality because of inheritance
127
- return type(quantizer_config) == GPTQQuantizerConfig
114
+ self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
115
+ else gptq_quantizer_params_override
128
116
 
129
117
 
130
118
  class GradientPTQConfigV2(GradientPTQConfig):
131
119
  """
132
120
  Configuration to use for quantization with GradientPTQV2 (experimental).
133
121
  """
134
- def __init__(self,
135
- n_epochs: int,
122
+ def __init__(self, n_epochs: int,
136
123
  optimizer: Any,
137
124
  optimizer_rest: Any = None,
138
125
  loss: Callable = None,
139
126
  log_function: Callable = None,
140
127
  train_bias: bool = True,
141
- quantization_parameters_learning: bool = False,
142
128
  rounding_type: RoundingType = RoundingType.SoftQuantizer,
143
- lsb_change_per_bit_width: dict = DefaultDict({}, lambda: 1),
144
- eps: float = 1e-6,
145
- use_jac_based_weights: bool = True,
146
- num_samples_for_loss: int = 16,
147
- norm_weights: bool = False,
129
+ use_hessian_based_weights: bool = True,
148
130
  optimizer_quantization_parameter: Any = None,
149
131
  optimizer_bias: Any = None,
150
- log_norm: bool = True,
151
- weights_n_iter: int = 50,
152
- quantizer_config: GPTQQuantizerConfig = SoftQuantizerConfig()):
132
+ regularization_factor: float = REG_DEFAULT,
133
+ hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
134
+ gptq_quantizer_params_override: Dict[str, Any] = None):
153
135
  """
154
136
  Initialize a GradientPTQConfigV2.
155
137
 
@@ -162,18 +144,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
162
144
  accordingly. see example in multiple_tensors_mse_loss
163
145
  log_function (Callable): Function to log information about the GPTQ process.
164
146
  train_bias (bool): Whether to update the bias during the training or not.
165
- quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
166
147
  rounding_type (RoundingType): An enum that defines the rounding type.
167
- lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
168
- eps (float): A floating point value for numeric stability.
169
- use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
170
- num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
171
- norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
148
+ use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
172
149
  optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
173
150
  optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
174
- log_norm (bool): Whether to use log normalization to the GPTQ Jacobian-based weights.
175
- weights_n_iter (int): Number of random iterations to run Jacobian approximation for GPTQ weights.
176
- quantizer_config (Any): A class that contains the quantizer specific config.
151
+ regularization_factor (float): A floating point number that defines the regularization factor.
152
+ hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
153
+ gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
177
154
 
178
155
  """
179
156
 
@@ -183,18 +160,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
183
160
  loss=loss,
184
161
  log_function=log_function,
185
162
  train_bias=train_bias,
186
- quantization_parameters_learning=quantization_parameters_learning,
187
163
  rounding_type=rounding_type,
188
- lsb_change_per_bit_width=lsb_change_per_bit_width,
189
- eps=eps,
190
- use_jac_based_weights=use_jac_based_weights,
191
- num_samples_for_loss=num_samples_for_loss,
192
- norm_weights=norm_weights,
164
+ use_hessian_based_weights=use_hessian_based_weights,
193
165
  optimizer_quantization_parameter=optimizer_quantization_parameter,
194
166
  optimizer_bias=optimizer_bias,
195
- log_norm=log_norm,
196
- weights_n_iter=weights_n_iter,
197
- quantizer_config=quantizer_config)
167
+ regularization_factor=regularization_factor,
168
+ hessian_weights_config=hessian_weights_config,
169
+ gptq_quantizer_params_override=gptq_quantizer_params_override)
198
170
  self.n_epochs = n_epochs
199
171
 
200
172
  @classmethod
@@ -211,22 +183,3 @@ class GradientPTQConfigV2(GradientPTQConfig):
211
183
  v1_params = config_v1.__dict__
212
184
  v1_params = {k: v for k, v in v1_params.items() if k != 'n_iter'}
213
185
  return cls(n_epochs, **v1_params)
214
-
215
- def get_extended_quantizer_parametes(self) -> Dict[str, Any]:
216
- """
217
- Return a dictionary with a mapping to necessary additional parameters for initializing the GPTQ quantizer.
218
-
219
- Returns: A dictionary with parameters for initializing a quantizer.
220
-
221
- """
222
-
223
- if self.rounding_type == RoundingType.SoftQuantizer:
224
- return {N_BATCHES_STR: self.quantizer_config.n_batches,
225
- QUANT_PARAM_LEARNING_STR: self.quantization_parameters_learning,
226
- N_EPOCHS_STR: self.n_epochs}
227
- elif self.rounding_type == RoundingType.STE:
228
- return {MAX_LSB_STR: self.lsb_change_per_bit_width}
229
-
230
- return {}
231
-
232
-
@@ -2,7 +2,6 @@
2
2
  AUXVAR = 'auxvar_tensor'
3
3
  ITERVAR = 'iteration_variable'
4
4
  SCALE_TENSOR = "scale_ptq_tensor"
5
- GPTQ_ITER = "gptq_iter"
6
5
  AUXSHIFT = 'shift'
7
6
  WEIGHTS_QUANTIZATION_PARAMS = 'weights_quantization_params'
8
7
  PTQ_MIN_RANGE = "min_range"
@@ -11,22 +10,16 @@ PTQ_THRESHOLD = "ptq_threshold"
11
10
  SCALE_PTQ = "scale"
12
11
 
13
12
  # Default quantizer values
14
- N_EPOCHS = 10000
15
13
  N_CYCLES = 4
16
14
  MIM_TEMP = 0.5
17
15
  MAX_TEMP = 1.0
18
16
  REG_DEFAULT = 0.01
19
- MAX_ITERATIONS_DEFAULT = 10000
20
17
  MAX_LSB_CHANGE = 1
21
18
 
22
19
  # Soft rounding arguments values
23
20
  SOFT_ROUNDING_GAMMA = -0.1
24
21
  SOFT_ROUNDING_ZETA = 1.1
25
- SOFT_ROUNDING_BETA = 2 / 3
26
22
 
27
23
  # GPTQ config constant
28
- REGULARIZATION_VALUES = "regularization_values"
29
- N_BATCHES_STR = 'n_batches'
30
24
  QUANT_PARAM_LEARNING_STR = 'quantization_parameter_learning'
31
- N_EPOCHS_STR = 'n_epochs'
32
25
  MAX_LSB_STR = 'max_lsbs_change_map'
@@ -0,0 +1,32 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from abc import abstractmethod
17
+
18
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
19
+
20
+
21
+ class GPTQFrameworkImplemantation(FrameworkImplementation):
22
+ """
23
+ Class to implement framework related methods that are used in GPTQ
24
+ """
25
+
26
+ @abstractmethod
27
+ def get_gptq_trainer_obj(self):
28
+ """
29
+ Returns: GPTQTrainer object
30
+ """
31
+ raise NotImplemented(f'{self.__class__.__name__} have to implement the '
32
+ f'framework\'s get_gptq_trainer method.') # pragma: no cover