mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Tuple, List
16
16
 
17
- from model_compression_toolkit import FrameworkInfo
18
- from model_compression_toolkit.core.common import Logger
17
+ from model_compression_toolkit.core import FrameworkInfo
18
+ from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.core.common.graph.base_graph import Graph
20
20
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
21
21
 
@@ -16,12 +16,14 @@ import copy
16
16
  from abc import ABC, abstractmethod
17
17
  import numpy as np
18
18
  from typing import Callable, List, Any
19
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType
20
- from model_compression_toolkit.core.common import Graph, Logger, BaseNode
19
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
20
+ from model_compression_toolkit.core.common import Graph, BaseNode
21
21
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
22
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
+ from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
23
24
  from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
24
25
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
26
+ from model_compression_toolkit.logger import Logger
25
27
 
26
28
 
27
29
  class GPTQTrainer(ABC):
@@ -34,8 +36,7 @@ class GPTQTrainer(ABC):
34
36
  graph_quant: Graph,
35
37
  gptq_config: GradientPTQConfig,
36
38
  fw_impl: FrameworkImplementation,
37
- fw_info: FrameworkInfo,
38
- representative_data_gen: Callable):
39
+ fw_info: FrameworkInfo):
39
40
  """
40
41
  Build two models from a graph: A teacher network (float model) and a student network (quantized model).
41
42
  Use the dataset generator to pass images through the teacher and student networks to get intermediate
@@ -48,7 +49,6 @@ class GPTQTrainer(ABC):
48
49
  gptq_config: GradientPTQConfig with parameters about the tuning process.
49
50
  fw_impl: Framework implementation
50
51
  fw_info: Framework information
51
- representative_data_gen: Dataset to use for inputs of the models.
52
52
  """
53
53
  self.graph_float = copy.deepcopy(graph_float)
54
54
  self.graph_quant = copy.deepcopy(graph_quant)
@@ -66,10 +66,6 @@ class GPTQTrainer(ABC):
66
66
  append2output=self.compare_points,
67
67
  fw_info=self.fw_info)
68
68
 
69
- if self.gptq_config.rounding_type == RoundingType.SoftQuantizer:
70
- # dry run on the representative dataset to count number of batches
71
- self.count_num_batches_for_training(representative_data_gen)
72
-
73
69
  self.fxp_model, self.gptq_user_info = self.build_gptq_model()
74
70
 
75
71
  def get_optimizer_with_param(self,
@@ -88,8 +84,10 @@ class GPTQTrainer(ABC):
88
84
 
89
85
  w2train = [*flattened_trainable_weights]
90
86
 
87
+ quant_params_learning = self.gptq_config.gptq_quantizer_params_override.get(QUANT_PARAM_LEARNING_STR, False)
88
+
91
89
  optimizer_with_param = [(self.gptq_config.optimizer, w2train)]
92
- if self.gptq_config.train_bias or self.gptq_config.quantization_parameters_learning:
90
+ if self.gptq_config.train_bias or quant_params_learning:
93
91
  w2train_res = []
94
92
  if self.gptq_config.train_bias:
95
93
  if self.gptq_config.optimizer_bias is not None:
@@ -99,7 +97,7 @@ class GPTQTrainer(ABC):
99
97
  if self.gptq_config.optimizer_rest is None:
100
98
  Logger.error( # pragma: no cover
101
99
  "To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
102
- if self.gptq_config.quantization_parameters_learning:
100
+ if quant_params_learning:
103
101
  if self.gptq_config.optimizer_quantization_parameter is not None: # Ability to override optimizer
104
102
  optimizer_with_param.append((self.gptq_config.optimizer_quantization_parameter,
105
103
  trainable_quantization_parameters))
@@ -107,25 +105,32 @@ class GPTQTrainer(ABC):
107
105
  w2train_res.extend(trainable_quantization_parameters)
108
106
  if self.gptq_config.optimizer_rest is None:
109
107
  Logger.error( # pragma: no cover
110
- "To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
111
- optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
108
+ "To enable quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
109
+ if len(w2train_res) > 0:
110
+ # Either bias or quantization parameters are trainable but did not provide a specific optimizer,
111
+ # so we should use optimizer_rest to train them
112
+ if self.gptq_config.optimizer_rest is None:
113
+ Logger.error( # pragma: no cover
114
+ "To enable bias or quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
115
+ optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
112
116
 
113
117
  return optimizer_with_param
114
118
 
115
119
 
116
- def compute_jacobian_based_weights(self,
117
- representative_data_gen: Callable) -> np.ndarray:
120
+ def compute_hessian_based_weights(self,
121
+ representative_data_gen: Callable) -> np.ndarray:
118
122
  """
119
- Computes the jacobian-based weights using the framework's model_grad method per batch of images.
123
+ Computes the Hessian-based weights using the framework's model_grad method per batch of images.
120
124
 
121
125
  Args:
122
- representative_data_gen: Dataset used for inference to compute the jacobian-based weights.
126
+ representative_data_gen: Dataset used for inference to compute the Hessian-based weights.
123
127
 
124
128
  Returns: A vector of weights, one for each compare point,
125
129
  to be used for the loss metric weighted average computation when running GPTQ training.
126
130
  """
127
- if self.gptq_config.use_jac_based_weights:
128
- images = self._generate_images_batch(representative_data_gen, self.gptq_config.num_samples_for_loss)
131
+ if self.gptq_config.use_hessian_based_weights:
132
+ images = self._generate_images_batch(representative_data_gen,
133
+ self.gptq_config.hessian_weights_config.hessians_num_samples)
129
134
 
130
135
  model_output_replacement = self._get_model_output_replacement()
131
136
 
@@ -143,17 +148,18 @@ class GPTQTrainer(ABC):
143
148
  output_list=model_output_replacement,
144
149
  all_outputs_indices=[],
145
150
  alpha=0,
146
- norm_weights=self.gptq_config.norm_weights,
147
- n_iter=self.gptq_config.weights_n_iter)
151
+ norm_weights=self.gptq_config.hessian_weights_config.norm_weights,
152
+ n_iter=self.gptq_config.hessian_weights_config.hessians_n_iter)
148
153
  points_apprx_jacobians_weights.append(image_ip_gradients)
149
- if self.gptq_config.log_norm:
154
+ if self.gptq_config.hessian_weights_config.log_norm:
150
155
  mean_jacobian_weights = np.mean(points_apprx_jacobians_weights, axis=0)
151
156
  mean_jacobian_weights = np.where(mean_jacobian_weights != 0, mean_jacobian_weights,
152
157
  np.partition(mean_jacobian_weights, 1)[1])
153
158
  log_weights = np.log10(mean_jacobian_weights)
154
159
 
155
- # To add scaling to the normalized weights replace return statement with the following line:
156
- # return log_weights - np.min(log_weights) / (np.max(log_weights) - np.min(log_weights))
160
+ if self.gptq_config.hessian_weights_config.scale_log_norm:
161
+ return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
162
+
157
163
  return log_weights - np.min(log_weights)
158
164
  else:
159
165
  return np.mean(points_apprx_jacobians_weights, axis=0)
@@ -249,21 +255,6 @@ class GPTQTrainer(ABC):
249
255
  replacement_outputs.append(prev_node)
250
256
  return replacement_outputs
251
257
 
252
- def count_num_batches_for_training(self, representative_data_gen):
253
- """
254
- Runs a "dry-run" of the representative dataset to count the number of batches for each training epoch.
255
-
256
- Args:
257
- representative_data_gen: A callable method to retrieve images from Dataset.
258
-
259
- Returns: The number of batches for each training epoch.
260
-
261
- """
262
- num_batches = 0
263
- for _ in representative_data_gen():
264
- num_batches += 1
265
- self.gptq_config.quantizer_config.set_num_batches(num_batches)
266
-
267
258
 
268
259
  def gptq_training(graph_float: Graph,
269
260
  graph_quant: Graph,
@@ -0,0 +1,29 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Type
17
+
18
+ from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
19
+ from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
20
+ from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
21
+
22
+
23
+ class GPTQKerasImplemantation(GPTQFrameworkImplemantation, KerasImplementation):
24
+
25
+ def get_gptq_trainer_obj(self) -> Type[KerasGPTQTrainer]:
26
+ """
27
+ Returns: Keras object of GPTQTrainer
28
+ """
29
+ return KerasGPTQTrainer
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from functools import partial
16
15
  from typing import Callable, List, Tuple, Union
17
16
 
18
17
  import tensorflow as tf
@@ -23,11 +22,11 @@ from tqdm import tqdm
23
22
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
24
23
  from model_compression_toolkit.core.common.user_info import UserInformation
25
24
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
26
- from model_compression_toolkit.gptq.common.gptq_constants import REGULARIZATION_VALUES
27
25
  from packaging import version
28
26
 
29
27
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
30
28
  from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
29
+ from model_compression_toolkit.logger import Logger
31
30
  from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
32
31
 
33
32
  if version.parse(tf.__version__) < version.parse("2.6"):
@@ -37,15 +36,15 @@ else:
37
36
 
38
37
  from model_compression_toolkit.core import common
39
38
  from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
40
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2, RoundingType
39
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
41
40
  from model_compression_toolkit.core.common import Graph
42
- from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, \
43
- get_soft_rounding_reg, get_gptq_trainable_parameters
41
+ from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
42
+ from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
44
43
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
45
44
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
46
45
  import numpy as np
47
46
  import copy
48
- from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS, KERNEL
47
+ from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
49
48
  from model_compression_toolkit import quantizers_infrastructure as qi
50
49
 
51
50
 
@@ -79,13 +78,12 @@ class KerasGPTQTrainer(GPTQTrainer):
79
78
  graph_quant,
80
79
  gptq_config,
81
80
  fw_impl,
82
- fw_info,
83
- representative_data_gen)
81
+ fw_info)
84
82
 
85
83
  self.loss_list = []
86
84
  self.input_scale = 1
87
85
 
88
- trainable_weights, bias_weights, trainable_threshold, temperature_weights = get_gptq_trainable_parameters(
86
+ trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
89
87
  self.fxp_model,
90
88
  fw_info,
91
89
  add_bias=gptq_config.train_bias)
@@ -108,11 +106,13 @@ class KerasGPTQTrainer(GPTQTrainer):
108
106
  [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
109
107
 
110
108
  if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
111
- common.Logger.error("Input scale mismatch between float and GPTQ networks") # pragma: no cover
109
+ Logger.error("Input scale mismatch between float and GPTQ networks") # pragma: no cover
112
110
  else:
113
111
  self.input_scale = self.gptq_user_info.input_scale
114
112
 
115
- self.weights_for_average_loss = self.compute_jacobian_based_weights(representative_data_gen)
113
+ self.weights_for_average_loss = self.compute_hessian_based_weights(representative_data_gen)
114
+
115
+ self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
116
116
 
117
117
  def _is_gptq_applicable(self,
118
118
  node: common.BaseNode) -> bool:
@@ -127,7 +127,7 @@ class KerasGPTQTrainer(GPTQTrainer):
127
127
  """
128
128
 
129
129
  if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
130
- common.Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
130
+ Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
131
131
  f"without a kernel isn't supported")
132
132
  return node.is_weights_quantization_enabled()
133
133
 
@@ -195,9 +195,7 @@ class KerasGPTQTrainer(GPTQTrainer):
195
195
  self.compare_points_std,
196
196
  self.weights_for_average_loss)
197
197
 
198
- reg_value = self.gptq_config.quantizer_config.get_regularization_value(
199
- self.fxp_model,
200
- **{REGULARIZATION_VALUES: self._get_quantizer_regularization_values(self.gptq_config.rounding_type)})
198
+ reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
201
199
 
202
200
  loss_value += reg_value
203
201
 
@@ -283,7 +281,7 @@ class KerasGPTQTrainer(GPTQTrainer):
283
281
  self.gptq_config.log_function(loss_value_step, grads[0], in_optimizer_with_param[0][-1],
284
282
  self.compare_points)
285
283
  self.loss_list.append(loss_value_step.numpy())
286
- common.Logger.debug(f'last loss value: {self.loss_list[-1]}')
284
+ Logger.debug(f'last loss value: {self.loss_list[-1]}')
287
285
 
288
286
  def update_graph(self):
289
287
  """
@@ -300,7 +298,7 @@ class KerasGPTQTrainer(GPTQTrainer):
300
298
  if len(node) == 0 and isinstance(layer.layer, TensorFlowOpLayer):
301
299
  node = graph.find_node_by_name('_'.join(layer.layer.name.split('_')[3:]))
302
300
  if len(node) != 1:
303
- common.Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
301
+ Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
304
302
  node = node[0]
305
303
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
306
304
  fw_info=self.fw_info)
@@ -319,18 +317,3 @@ class KerasGPTQTrainer(GPTQTrainer):
319
317
  node.set_weights_by_keys(BIAS, new_bias)
320
318
 
321
319
  return graph
322
-
323
- def _get_quantizer_regularization_values(self, rounding_type: RoundingType) -> List[tf.Tensor]:
324
- """
325
- Mapping between a rounding type to its matching regularization method.
326
-
327
- Args:
328
- rounding_type: GPTQ rounding type.
329
-
330
- Returns: A regularization computation method.
331
-
332
- """
333
- if rounding_type == RoundingType.SoftQuantizer:
334
- return get_soft_rounding_reg(self.fxp_model)
335
- else:
336
- return []
@@ -13,23 +13,21 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
-
17
16
  import tensorflow as tf
18
17
  from typing import Tuple, List
19
-
20
18
  from model_compression_toolkit.core.keras.constants import USE_BIAS
21
19
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
20
  from tensorflow.keras.models import Model
23
-
24
21
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
25
22
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
26
23
  from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
24
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
27
25
 
28
26
 
29
27
  def get_gptq_trainable_parameters(fxp_model: Model,
30
28
  fw_info: FrameworkInfo,
31
29
  add_bias: bool = False) -> (
32
- List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable]):
30
+ List[tf.Variable], List[tf.Variable], List[tf.Variable]):
33
31
  """
34
32
  Get trainable parameters from all layers in a model
35
33
 
@@ -45,16 +43,17 @@ def get_gptq_trainable_parameters(fxp_model: Model,
45
43
  trainable_weights: List[tf.Tensor] = []
46
44
  trainable_threshold: List[tf.Tensor] = []
47
45
  bias_weights: List[List[tf.Tensor]] = []
48
- temperature_weights: List[tf.Tensor] = []
49
46
 
50
47
  for layer in fxp_model.layers:
51
48
  if isinstance(layer, KerasQuantizationWrapper):
52
49
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
53
50
  fw_info=DEFAULT_KERAS_INFO)
54
51
 
55
- # collect trainable weights per layer
56
- layer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_aux_variable()
57
- layer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_quantization_variable()
52
+ # collect trainable weights per quantizer
53
+ quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
54
+ quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
55
+ trainable_weights.append(quantizer_trainable_weights)
56
+ trainable_threshold.extend(quantizer_trainable_threshold)
58
57
 
59
58
  if add_bias:
60
59
  kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
@@ -62,10 +61,8 @@ def get_gptq_trainable_parameters(fxp_model: Model,
62
61
  and layer.layer.get_config().get(USE_BIAS)
63
62
  if use_bias is not None and use_bias:
64
63
  bias_weights.append([layer.layer.bias])
65
- trainable_weights.append(layer_trainable_weights)
66
- trainable_threshold.extend(layer_trainable_threshold)
67
64
 
68
- return trainable_weights, bias_weights, trainable_threshold, temperature_weights
65
+ return trainable_weights, bias_weights, trainable_threshold
69
66
 
70
67
 
71
68
  def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
@@ -95,25 +92,3 @@ def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
95
92
  fxp_weights_list.append(_layer_fxp_weights)
96
93
 
97
94
  return flp_weights_list, fxp_weights_list
98
-
99
-
100
- # TODO: this function need to move to location that is relevant only for soft quantizer -
101
- # once deciding how to handle GPTQ quantizers regularization.
102
- def get_soft_rounding_reg(fxp_model: Model) -> List[tf.Tensor]:
103
- """
104
- This function returns the soft quantizer regularization values for SoftRounding.
105
-
106
- Args:
107
- fxp_model: A model to be quantized with SoftRounding.
108
-
109
- Returns: A list of tensors.
110
- """
111
-
112
- soft_reg_aux: List[tf.Tensor] = []
113
- for layer in fxp_model.layers:
114
- if isinstance(layer, KerasQuantizationWrapper):
115
- kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
116
- fw_info=DEFAULT_KERAS_INFO)
117
-
118
- soft_reg_aux.append(layer.weights_quantizers[kernel_attribute].get_regularization())
119
- return soft_reg_aux
@@ -16,21 +16,19 @@
16
16
  from typing import Callable, Tuple
17
17
  from packaging import version
18
18
 
19
- from model_compression_toolkit.core import common
20
- from model_compression_toolkit.core.common import Logger
21
- from model_compression_toolkit.core.common.constants import TENSORFLOW
19
+ from model_compression_toolkit.logger import Logger
20
+ from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
22
21
  from model_compression_toolkit.core.common.user_info import UserInformation
23
22
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
24
23
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
25
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
26
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
27
- MixedPrecisionQuantizationConfigV2
28
- from model_compression_toolkit import CoreConfig
25
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
26
+ from model_compression_toolkit.core import CoreConfig
29
27
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
30
28
  from model_compression_toolkit.gptq.runner import gptq_runner
31
29
  from model_compression_toolkit.core.exporter import export_model
32
30
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
33
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
31
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
34
32
 
35
33
  LR_DEFAULT = 0.15
36
34
  LR_REST_DEFAULT = 1e-4
@@ -38,14 +36,14 @@ LR_BIAS_DEFAULT = 1e-4
38
36
  LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
39
37
  GPTQ_MOMENTUM = 0.9
40
38
 
41
- if common.constants.FOUND_TF:
39
+ if FOUND_TF:
42
40
  import tensorflow as tf
43
41
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
44
- from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
42
+ from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
45
43
  from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
46
44
  from tensorflow.keras.models import Model
47
45
  from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
48
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
46
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
49
47
  from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
50
48
  from model_compression_toolkit import get_target_platform_capabilities
51
49
 
@@ -62,7 +60,8 @@ if common.constants.FOUND_TF:
62
60
  optimizer: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT),
63
61
  optimizer_rest: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT),
64
62
  loss: Callable = GPTQMultipleTensorsLoss(),
65
- log_function: Callable = None) -> GradientPTQConfigV2:
63
+ log_function: Callable = None,
64
+ use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
66
65
  """
67
66
  Create a GradientPTQConfigV2 instance for Keras models.
68
67
 
@@ -72,6 +71,7 @@ if common.constants.FOUND_TF:
72
71
  optimizer_rest (OptimizerV2): Keras optimizer to use for fine-tuning of the bias variable.
73
72
  loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
74
73
  log_function (Callable): Function to log information about the gptq process.
74
+ use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
75
75
 
76
76
  returns:
77
77
  a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -85,24 +85,25 @@ if common.constants.FOUND_TF:
85
85
 
86
86
  Create a GradientPTQConfigV2 to run for 5 epochs:
87
87
 
88
- >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=5)
88
+ >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=5)
89
89
 
90
90
  Other Tensorflow optimizers can be passed:
91
91
 
92
- >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())
92
+ >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())
93
93
 
94
94
  The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.
95
95
 
96
96
  """
97
- bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
97
+ bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
98
+ momentum=GPTQ_MOMENTUM)
98
99
  return GradientPTQConfigV2(n_epochs,
99
100
  optimizer,
100
101
  optimizer_rest=optimizer_rest,
101
102
  loss=loss,
102
103
  log_function=log_function,
103
104
  train_bias=True,
104
- quantization_parameters_learning=True,
105
- optimizer_bias=bias_optimizer)
105
+ optimizer_bias=bias_optimizer,
106
+ use_hessian_based_weights=use_hessian_based_weights)
106
107
 
107
108
 
108
109
  def keras_gradient_post_training_quantization_experimental(in_model: Model,
@@ -164,28 +165,28 @@ if common.constants.FOUND_TF:
164
165
 
165
166
  Create an MCT core config, containing the quantization configuration:
166
167
 
167
- >>> config = mct.CoreConfig()
168
+ >>> config = mct.core.CoreConfig()
168
169
 
169
170
  If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
170
171
  with different bitwidths for different layers.
171
172
  The candidates bitwidth for quantization should be defined in the target platform model:
172
173
 
173
- >>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2(num_of_images=1))
174
+ >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
174
175
 
175
176
  For mixed-precision set a target KPI object:
176
177
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
177
178
  that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
178
179
  while the bias will not):
179
180
 
180
- >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
181
+ >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
181
182
 
182
183
  Create GPTQ config:
183
184
 
184
- >>> gptq_config = mct.get_keras_gptq_config(n_epochs=1)
185
+ >>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)
185
186
 
186
187
  Pass the model with the representative dataset generator to get a quantized model:
187
188
 
188
- >>> quantized_model, quantization_info = mct.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
189
+ >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
189
190
 
190
191
  """
191
192
  KerasModelValidation(model=in_model,
@@ -193,15 +194,15 @@ if common.constants.FOUND_TF:
193
194
 
194
195
  if core_config.mixed_precision_enable:
195
196
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
196
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
197
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
197
198
  "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
198
199
  "API, or pass a valid mixed precision configuration.") # pragma: no cover
199
200
 
200
- common.Logger.info("Using experimental mixed-precision quantization. "
201
+ Logger.info("Using experimental mixed-precision quantization. "
201
202
  "If you encounter an issue please file a bug.")
202
203
  tb_w = _init_tensorboard_writer(fw_info)
203
204
 
204
- fw_impl = KerasImplementation()
205
+ fw_impl = GPTQKerasImplemantation()
205
206
 
206
207
  tg, bit_widths_config = core_runner(in_model=in_model,
207
208
  representative_data_gen=representative_data_gen,
@@ -15,3 +15,4 @@
15
15
 
16
16
  import model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste
17
17
  import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.symmetric_soft_quantizer
18
+ import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.uniform_soft_quantizer
@@ -15,8 +15,8 @@
15
15
  from abc import abstractmethod
16
16
  from typing import Union, Dict, List
17
17
 
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import FOUND_TF
20
20
  from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
21
21
 
22
22
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
@@ -45,7 +45,6 @@ if FOUND_TF:
45
45
 
46
46
  super().__init__(quantization_config)
47
47
 
48
- self.quantizer_parameters = None
49
48
 
50
49
  def update_layer_quantization_params(self, layer: KerasQuantizationWrapper
51
50
  ) -> (Dict[str, tf.Tensor], Dict[str, Dict], Dict):
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import tensorflow as tf
17
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD
17
+ from model_compression_toolkit.constants import MIN_THRESHOLD
18
18
  from typing import Tuple
19
19
 
20
20
 
@@ -26,6 +26,14 @@ def ste_ceil(x: tf.Tensor) -> tf.Tensor:
26
26
  return error + x
27
27
 
28
28
 
29
+ def ste_floor(x: tf.Tensor) -> tf.Tensor:
30
+ """
31
+ Return the floor values of a tensor.
32
+ """
33
+ error = tf.stop_gradient(tf.math.floor(x) - x)
34
+ return error + x
35
+
36
+
29
37
  def safe_log(x: tf.Tensor, eps: float) -> tf.Tensor:
30
38
  """
31
39
  Computes log function of x unless x is smaller than some small value, so the log function would not fail.
@@ -72,6 +80,15 @@ def calculate_delta(max_tensor: tf.Tensor,
72
80
  return max_tensor / (2 ** (num_bits - int(signed)))
73
81
 
74
82
 
83
+ def calculate_delta_uniform(min_tensor: tf.Tensor,
84
+ max_tensor: tf.Tensor,
85
+ num_bits: int) -> tf.Tensor:
86
+ """
87
+ Compute the step size for the uniform quantization.
88
+ """
89
+ return (max_tensor-min_tensor) / (2 ** num_bits - 1)
90
+
91
+
75
92
  def ste_clip(x: [tf.Tensor, tf.Variable], max_val=1, min_val=None) -> tf.Tensor:
76
93
  """
77
94
  clip a variable between fixed values such that min_val<=output<=max_val
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import Dict, List, Tuple
16
16
 
17
- from model_compression_toolkit import GradientPTQConfigV2
17
+ from model_compression_toolkit.gptq import GradientPTQConfigV2
18
18
  from model_compression_toolkit.core import common
19
19
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
20
20
  from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
@@ -61,7 +61,7 @@ def quantization_builder(n: common.BaseNode,
61
61
  fw_info=DEFAULT_KERAS_INFO)
62
62
 
63
63
  weights_quantizers.update({kernel_attribute: quantizer_class(get_trainable_quantizer_weights_config(n),
64
- **gptq_config.get_extended_quantizer_parametes())})
64
+ **gptq_config.gptq_quantizer_params_override)})
65
65
 
66
66
  activation_quantizers = []
67
67
  if n.is_activation_quantization_enabled():