mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -0,0 +1,45 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import Callable
+
+ from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
+ from model_compression_toolkit.gptq.keras.quantizer.soft_rounding.soft_quantizer_reg import \
+     SoftQuantizerRegularization
+
+
+ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
+     """
+     Returns a function that computes the regularization term for GPTQ training based on the given
+     rounding type in the GPTQ configuration.
+
+     Args:
+         gptq_config: A GPTQ configuration.
+         representative_data_gen: Dataset used for the GPTQ training.
+
+     Returns: A function for computing the regularization. If there is no regularization function defined for the given
+     rounding type, then it returns a function that just returns 0.
+
+     """
+     if gptq_config.rounding_type == RoundingType.SoftQuantizer:
+         # dry run on the representative dataset to count number of batches
+         num_batches = 0
+         for _ in representative_data_gen():
+             num_batches += 1
+
+         n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
+             not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
+         return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+     else:
+         return lambda m, e_reg: 0
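For orientation, a minimal sketch of how the callable returned by get_regularization above is typically consumed during a GPTQ training step; gptq_config, representative_data_gen, model, task_loss and entropy_reg are illustrative placeholders, not names taken from this diff:

    # Hypothetical usage sketch (not part of the package diff).
    reg_fn = get_regularization(gptq_config, representative_data_gen)
    # Each optimization step adds the rounding regularization to the task loss:
    reg_term = reg_fn(model, entropy_reg)  # tf scalar, or 0 when no regularization is defined for the rounding type
    total_loss = task_loss + reg_term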
@@ -0,0 +1,110 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import List
+
+ import tensorflow as tf
+ from keras import Model
+
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+ from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+ from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+
+
+ class LinearTempDecay:
+     """
+     Annealing process for the soft quantizer regularization temperature term.
+     """
+
+     def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
+         """
+         Initializes a LinearTempDecay object.
+
+         Args:
+             t_max: maximal time step.
+             rel_start_decay: Decay step size at the beginning of the process.
+             start_b: Starting value of the regularization term.
+             end_b: Target value of the regularization term.
+         """
+
+         self.t_max = t_max
+         self.start_decay = rel_start_decay * t_max
+         self.start_b = start_b
+         self.end_b = end_b
+
+     def __call__(self, t: int) -> float:
+         """
+         Cosine annealing scheduler for soft quantizer regularization temperature term.
+
+         Args:
+             t: The current time step.
+
+         Returns: Scheduled temperature.
+         """
+
+         is_before_start_decay = tf.cast(t < self.start_decay, tf.float32)
+
+         rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
+
+         return self.start_b * is_before_start_decay + \
+                (1 - is_before_start_decay) * \
+                (self.end_b + (self.start_b - self.end_b) * tf.math.maximum(0.0, (1 - rel_t)))
+
+
+ class SoftQuantizerRegularization:
+     """
+     A class to handle the computation of soft quantizer regularization for GPTQ training.
+     """
+
+     def __init__(self, total_gradient_steps: int):
+         """
+         Initializes the regularization computation object with a LinearDecay object.
+
+         Args:
+             total_gradient_steps: The number of gradient steps during optimization.
+         """
+         # Initializing the temperature decay according to the number of expected gradient steps
+         self.linear_decay = LinearTempDecay(total_gradient_steps)
+
+         self.count_iter = tf.Variable(0.)
+
+
+     def __call__(self, model: Model, entropy_reg: float):
+         """
+         Returns the soft quantizer regularization value for SoftRounding.
+
+         Args:
+             model: A model to be quantized with SoftRounding.
+             entropy_reg: Entropy value to scale the quantizer regularization.
+
+         Returns: Regularization value.
+         """
+         soft_reg_aux: List[tf.Tensor] = []
+         b = self.linear_decay(self.count_iter.value())
+         for layer in model.layers:
+             if isinstance(layer, KerasQuantizationWrapper):
+                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                       fw_info=DEFAULT_KERAS_INFO)
+
+                 st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
+                 soft_reg_aux.append(tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b)))
+
+         reg = 0
+
+         for sq in soft_reg_aux:
+             reg += sq
+
+         self.count_iter.assign_add(1.0)
+
+         return entropy_reg * reg
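In equation form, the temperature schedule implemented by LinearTempDecay above is, with t_d = rel_start_decay * t_max:

    b(t) = \begin{cases} b_{\mathrm{start}} & t < t_d \\ b_{\mathrm{end}} + (b_{\mathrm{start}} - b_{\mathrm{end}}) \cdot \max\left(0,\ 1 - \frac{t - t_d}{t_{\max} - t_d}\right) & t \ge t_d \end{cases}

and SoftQuantizerRegularization returns

    \texttt{entropy\_reg} \cdot \sum_{\ell \in \text{wrapped layers}} \sum_{i} \left(1 - \left|2\,(h_{\ell,i} - 0.5)\right|^{\,b(t)}\right)

where h_{\ell,i} are the soft rounding targets returned by each kernel quantizer's get_soft_targets().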
@@ -16,22 +16,22 @@
  import tensorflow as tf
  import numpy as np
 
- from model_compression_toolkit import RoundingType
+ from model_compression_toolkit.gptq import RoundingType
  from model_compression_toolkit import quantizers_infrastructure as qi
  from model_compression_toolkit.core.common import max_power_of_two
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
- from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, N_EPOCHS, \
-     MAX_ITERATIONS_DEFAULT, SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, SOFT_ROUNDING_BETA, GPTQ_ITER, AUXVAR
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
+     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
  from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
- from typing import Dict, Any, List
- from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
- from model_compression_toolkit.core.common.logger import Logger
+ from typing import Dict, Any
+ from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
  from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
  from model_compression_toolkit.gptq.keras.quantizer.quant_utils import power_of_two_max, clip, calculate_delta
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
      get_threshold_reshape_shape
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
 
  def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
@@ -66,46 +66,6 @@ def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
      return delta * clip(tensor_q, max_val=max_int, min_val=min_int)
 
 
- class LinearTempDecay:
-     """
-     Annealing process for the soft quantizer regularization temperature term.
-     """
-
-     def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
-         """
-         Initializes a LinearTempDecay object.
-
-         Args:
-             t_max: maximal time step.
-             rel_start_decay: Decay step size at the beginning of the process.
-             start_b: Starting value of the regularization term.
-             end_b: Target value of the regularization term.
-         """
-
-         self.t_max = t_max
-         self.start_decay = rel_start_decay * t_max
-         self.start_b = start_b
-         self.end_b = end_b
-
-     def __call__(self, t: int) -> float:
-         """
-         Cosine annealing scheduler for soft quantizer regularization temperature term.
-
-         Args:
-             t: The current time step.
-
-         Returns: Scheduled temperature.
-         """
-
-         is_before_start_decay = tf.cast(t < self.start_decay, tf.float32)
-
-         rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
-
-         return self.start_b * is_before_start_decay + \
-                (1 - is_before_start_decay) * \
-                (self.end_b + (self.start_b - self.end_b) * tf.math.maximum(0.0, (1 - rel_t)))
-
-
  @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                  quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                  quantizer_type=RoundingType.SoftQuantizer)
@@ -116,23 +76,15 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
 
      def __init__(self,
                   quantization_config: TrainableQuantizerWeightsConfig,
-                  n_batches: int = None,
-                  quantization_parameter_learning: bool = False,
-                  n_epochs: int = N_EPOCHS):
+                  quantization_parameter_learning: bool = False):
          """
          Initialize a SymmetricSoftRoundingGPTQ object with parameters to use
          for the quantization.
 
          Args:
              quantization_config: Trainable weights quantizer config.
-             n_batches: The expected number of batches for each training epoch.
              quantization_parameter_learning: Whether to train the quantization threshold.
-             n_epochs: Number of epochs to run training for.
          """
-
-         if n_batches is None:
-             Logger.error("SymmetricSoftRoundingGPTQ got an uninitialized n_batches argument.")
-
          super().__init__(quantization_config)
          self.num_bits = quantization_config.weights_n_bits
          self.per_channel = quantization_config.weights_per_channel_threshold
@@ -148,32 +100,23 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
          self.num_channels = len(self.threshold_values) if self.per_channel else 1
 
          # gamma and zeta are stretch parameters for computing the rectified sigmoind function.
-         # beta is used to set the regularization term.
          # See: https://arxiv.org/pdf/2004.10568.pdf
          self.gamma = SOFT_ROUNDING_GAMMA
          self.zeta = SOFT_ROUNDING_ZETA
-         self.beta = SOFT_ROUNDING_BETA
 
          self.quantizer_parameters = {}
 
-         # Initializing the temperature decay according to the number of expected gradient steps
-         init_decay = MAX_ITERATIONS_DEFAULT if n_batches is None else n_epochs * n_batches
-         self.linear_decay = LinearTempDecay(init_decay)
-
      def initialize_quantization(self,
                                  tensor_shape: Any,
                                  name: str,
-                                 layer: Any) -> Dict[Any, Any]:
+                                 layer: Any):
          """
-         Return a dictionary of quantizer parameters and their names.
+         Add quantizer parameters to the quantizer parameters dictionary
 
          Args:
              tensor_shape: tensor shape of the quantized tensor.
              name: Tensor name.
              layer: Layer to quantize.
-
-         Returns:
-             Dictionary of parameters names to the variables.
          """
 
          if self.per_channel:
@@ -183,12 +126,6 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
          else:
              reshape_shape = [self.num_channels]
 
-         ar_iter = layer.add_weight(
-             f"{name}_{GPTQ_ITER}",
-             shape=(),
-             initializer=tf.keras.initializers.Constant(0.0),
-             trainable=False)
-
          ptq_threshold_tensor = layer.add_weight(
              f"{name}_{PTQ_THRESHOLD}",
              shape=reshape_shape,
@@ -212,44 +149,17 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
 
          auxvar_tensor.assign(alpha)
 
-         self.quantizer_parameters.update({AUXVAR: auxvar_tensor,
-                                           PTQ_THRESHOLD: ptq_threshold_tensor,
-                                           GPTQ_ITER: ar_iter})
+         # Add quantization variables
+         self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
+         self.add_quantizer_variable(PTQ_THRESHOLD, ptq_threshold_tensor, VariableGroup.QPARAMS)
 
-         if self.quantization_parameter_learning:
+         if self.quantization_parameter_learning and not self.power_of_two:
              scale = layer.add_weight(
                  f"{name}_{SCALE_PTQ}",
                  shape=self.num_channels,
                  initializer=tf.keras.initializers.Constant(1.0),
                  trainable=True)
-             self.quantizer_parameters.update({SCALE_PTQ: scale})
-
-         return self.quantizer_parameters
-
-     def get_quantization_variable(self) -> List[tf.Tensor]:
-         """
-         This function return a list with the quantizer's quantization parameters variables.
-
-         Returns: A list with the quantization parameters if there are defined parameters.
-
-         """
-
-         if self.quantization_parameter_learning and not self.power_of_two:
-             return [self.quantizer_parameters[SCALE_PTQ]]
-         else:
-             return []
-
-     def get_regularization(self) -> tf.Tensor:
-         """
-         Computes the regularization term for the soft rounding loss.
-
-         Returns:
-             regularization term.
-         """
-
-         st = self.get_soft_targets()
-         b = self.linear_decay(self.ar_iter.value())
-         return tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b))
+             self.add_quantizer_variable(SCALE_PTQ, scale, VariableGroup.QPARAMS)
 
      def get_soft_targets(self) -> tf.Tensor:
          """
@@ -260,16 +170,7 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
 
          """
          return qutils.clip(
-             tf.sigmoid(self.quantizer_parameters[AUXVAR]) * (self.zeta - self.gamma) + self.gamma, 1, 0)
-
-     def get_aux_variable(self) -> List[tf.Tensor]:
-         """
-         This function return a list with the quantizer's quantization auxiliary variables.
-
-         Returns: A list with the quantization auxiliary variables.
-
-         """
-         return [self.quantizer_parameters[AUXVAR]]
+             tf.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma, 1, 0)
 
      def __call__(self,
                   inputs: tf.Tensor,
@@ -285,8 +186,14 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
              The quantized tensor.
          """
 
-         self.ar_iter = self.quantizer_parameters[GPTQ_ITER]
-         ptq_threshold_tensor = self.quantizer_parameters[PTQ_THRESHOLD]
+         ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
+
+         #####################################################
+         # Soft Rounding
+         #####################################################
+         aux_var = self.get_soft_targets()
+         if not training:
+             aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
 
          if self.per_channel:
              reshape_shape = get_threshold_reshape_shape(inputs.shape,
@@ -297,15 +204,6 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
              # Calculate soft rounding targets and optimized threshold
              ##########################################################
              ptq_threshold_tensor_hat = tf.reshape(ptq_threshold_tensor, reshape_shape)
-             aux_var = self.get_soft_targets()
-
-             #####################################################
-             # Soft Rounding
-             #####################################################
-             if training:
-                 self.ar_iter.assign_add(1.0)
-             else:
-                 aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
 
              #####################################################
              # Quantized Input
@@ -318,17 +216,22 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
                                                           power_of_two=self.power_of_two)
 
              if self.quantization_parameter_learning and not self.power_of_two:
-                 scale = tf.reshape(self.quantizer_parameters[SCALE_PTQ], reshape_shape)
+                 scale = tf.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
                  q_tensor *= scale
 
-             return q_tensor
          else:
-             return soft_rounding_symmetric_quantizer(input_tensor=inputs,
-                                                      auxvar_tensor=self.quantizer_parameters[AUXVAR],
-                                                      threshold_tensor=ptq_threshold_tensor.value(),
-                                                      num_bits=self.num_bits,
-                                                      signed=True,
-                                                      power_of_two=self.power_of_two)
+             q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
+                                                          auxvar_tensor=aux_var,
+                                                          threshold_tensor=ptq_threshold_tensor.value(),
+                                                          num_bits=self.num_bits,
+                                                          signed=True,
+                                                          power_of_two=self.power_of_two)
+
+             if self.quantization_parameter_learning and not self.power_of_two:
+                 scale = self.get_quantizer_variable(SCALE_PTQ)
+                 q_tensor *= scale
+
+         return q_tensor
 
      def get_quant_config(self) -> Dict[str, np.ndarray]:
          """
@@ -340,13 +243,13 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
          """
 
          if self.power_of_two:
-             old_threshold = self.quantizer_parameters[PTQ_THRESHOLD]
+             old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
              old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
 
          else:
-             old_threshold = self.quantizer_parameters[PTQ_THRESHOLD]
+             old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
              if self.quantization_parameter_learning:
-                 scale = tf.reshape(self.quantizer_parameters[SCALE_PTQ], self.threshold_shape)
+                 scale = tf.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
                  old_threshold = old_threshold * scale
              old_threshold = old_threshold.numpy()
          old_threshold = old_threshold.reshape(self.threshold_shape)
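As a reading aid: the soft targets used by SymmetricSoftRoundingGPTQ are the rectified sigmoid of arXiv:2004.10568 referenced in the code, h(V) = \mathrm{clip}(\sigma(V)\,(\zeta - \gamma) + \gamma,\ 0,\ 1), as computed in get_soft_targets above; when training is False, __call__ hardens them to \mathbb{1}[h(V) \ge 0.5] before quantizing.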
@@ -0,0 +1,224 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import tensorflow as tf
+ import numpy as np
+
+ from model_compression_toolkit.gptq import RoundingType
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_constants import \
+     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
+ from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
+ from typing import Dict, Any
+ from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
+ from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+     get_threshold_reshape_shape
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+
+
+ def soft_rounding_uniform_quantizer(input_tensor: tf.Tensor,
+                                     auxvar_tensor: tf.Variable,
+                                     min_tensor: tf.Tensor,
+                                     max_tensor: tf.Tensor,
+                                     num_bits: int) -> tf.Tensor:
+     """
+     Quantize a tensor uniformly for GPTQ quantizers.
+
+     Args:
+         input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
+         auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to gptq training.
+         min_tensor: Tensor with values to compute the min threshold.
+         max_tensor: Tensor with values to compute the max threshold.
+         num_bits: Num of bits to use.
+
+     Returns:
+         A quantized tensor.
+     """
+     # adjusts the quantization range so the quantization grid includes zero.
+     min_range, max_range = qutils.fix_range_to_include_zero(min_tensor, max_tensor, num_bits)
+     delta = qutils.calculate_delta_uniform(min_range, max_range, num_bits)
+     input_tensor_int = qutils.ste_floor((input_tensor - min_range) / delta)
+     tensor_q = input_tensor_int + auxvar_tensor
+     return delta * qutils.ste_clip(tensor_q,
+                                    min_val=0,
+                                    max_val=2 ** num_bits - 1) + min_range
+
+
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                 quantization_method=[QuantizationMethod.UNIFORM],
+                 quantizer_type=RoundingType.SoftQuantizer)
+ class UniformSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
+     """
+     Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
+     """
+
+     def __init__(self,
+                  quantization_config: TrainableQuantizerWeightsConfig,
+                  quantization_parameter_learning: bool = False):
+         """
+         Initialize a UniformSoftRoundingGPTQ object with parameters to use
+         for the quantization.
+
+         Args:
+             quantization_config: Trainable weight quantizer config.
+             quantization_parameter_learning: Whether to train the quantization threshold.
+         """
+         super().__init__(quantization_config)
+         self.num_bits = quantization_config.weights_n_bits
+         self.per_channel = quantization_config.weights_per_channel_threshold
+
+         self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
+         self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
+
+         self.quantization_axis = quantization_config.weights_channels_axis
+         assert quantization_parameter_learning is False, \
+             "Quantization parameters learning in UniformSoftRoundingGPTQ not implemented yet"
+         self.quantization_parameter_learning = quantization_parameter_learning
+         self.num_channels = self.min_values.shape[self.quantization_axis] if self.per_channel else 1
+
+         # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
+         # See: https://arxiv.org/pdf/2004.10568.pdf
+         self.gamma = SOFT_ROUNDING_GAMMA
+         self.zeta = SOFT_ROUNDING_ZETA
+
+     def initialize_quantization(self,
+                                 tensor_shape: Any,
+                                 name: str,
+                                 layer: Any):
+         """
+         Add quantizer parameters to the quantizer parameters dictionary
+
+         Args:
+             tensor_shape: tensor shape of the quantized tensor.
+             name: Tensor name.
+             layer: Layer to quantize.
+         """
+
+         if self.per_channel:
+             reshape_shape = get_threshold_reshape_shape(tensor_shape,
+                                                         quant_axis=self.quantization_axis,
+                                                         quant_axis_dim=self.num_channels)
+         else:
+             reshape_shape = [self.num_channels]
+
+         min_tensor = layer.add_weight(
+             f"{name}_{FQ_MIN}",
+             shape=reshape_shape,
+             initializer=tf.keras.initializers.Constant(1.0),
+             trainable=False)
+         min_tensor.assign(self.min_values.reshape(reshape_shape))
+
+         max_tensor = layer.add_weight(
+             f"{name}_{FQ_MAX}",
+             shape=reshape_shape,
+             initializer=tf.keras.initializers.Constant(1.0),
+             trainable=False)
+         max_tensor.assign(self.max_values.reshape(reshape_shape))
+
+         w = getattr(layer.layer, name)
+         auxvar_tensor = layer.add_weight(
+             f"{name}_{AUXVAR}",
+             shape=list(w.shape),
+             initializer=tf.keras.initializers.Constant(0.0),
+             trainable=True)
+
+         w = layer.layer.depthwise_kernel if isinstance(layer.layer, (tf.keras.layers.DepthwiseConv2D,
+                                                                      tf.keras.layers.DepthwiseConv1D)) \
+             else layer.layer.kernel
+         delta = qutils.calculate_delta_uniform(min_tensor, max_tensor, self.num_bits)
+         w_clipped_normed = qutils.clip((w - min_tensor) / delta, 0, 2 ** self.num_bits - 1)
+         rest = w_clipped_normed - tf.floor(w_clipped_normed)  # rest of rounding [0, 1)
+         alpha = -qutils.safe_log((self.zeta - self.gamma) / (rest - self.gamma) - 1, 1e-16)  # => sigmoid(alpha) = rest
+         auxvar_tensor.assign(alpha)
+
+         # Add quantization variables
+         self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
+         self.add_quantizer_variable(RANGE_MIN, min_tensor, VariableGroup.QPARAMS)
+         self.add_quantizer_variable(RANGE_MAX, max_tensor, VariableGroup.QPARAMS)
+
+     def get_soft_targets(self) -> tf.Tensor:
+         """
+         Computes the rectified sigmoid function for the quantization target parameters.
+
+         Returns:
+             A tensor with the soft rounding targets values.
+
+         """
+         return qutils.clip(
+             tf.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma, 1, 0)
+
+     def __call__(self,
+                  inputs: tf.Tensor,
+                  training: bool):
+         """
+         Quantize a tensor.
+
+         Args:
+             inputs: Input tensor to quantize.
+             training: Whether the graph is in training mode.
+
+         Returns:
+             The quantized tensor.
+         """
+
+         min_tensor = self.get_quantizer_variable(RANGE_MIN)
+         max_tensor = self.get_quantizer_variable(RANGE_MAX)
+
+         #####################################################
+         # Soft Rounding
+         #####################################################
+         aux_var = self.get_soft_targets()
+         if not training:
+             aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
+
+         if self.per_channel:
+             reshape_shape = get_threshold_reshape_shape(inputs.shape,
+                                                         quant_axis=self.quantization_axis,
+                                                         quant_axis_dim=-1)
+
+             #####################################################
+             # Quantized Input
+             #####################################################
+             q_tensor = soft_rounding_uniform_quantizer(input_tensor=inputs,
+                                                        auxvar_tensor=aux_var,
+                                                        min_tensor=tf.reshape(min_tensor, reshape_shape),
+                                                        max_tensor=tf.reshape(max_tensor, reshape_shape),
+                                                        num_bits=self.num_bits)
+
+         else:
+             q_tensor = soft_rounding_uniform_quantizer(input_tensor=inputs,
+                                                        auxvar_tensor=aux_var,
+                                                        min_tensor=min_tensor,
+                                                        max_tensor=max_tensor,
+                                                        num_bits=self.num_bits)
+
+         return q_tensor
+
+     def get_quant_config(self) -> Dict[str, np.ndarray]:
+         """
+         Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
+
+         Returns:
+             A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+             Keys must match NodeQuantizationConfig attributes
+         """
+
+         return {RANGE_MIN: self.get_quantizer_variable(RANGE_MIN).numpy(),
+                 RANGE_MAX: self.get_quantizer_variable(RANGE_MAX).numpy()}
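In equation form, soft_rounding_uniform_quantizer above computes (assuming calculate_delta_uniform yields the usual uniform step \Delta = (q_{\max} - q_{\min}) / (2^{n} - 1) after the range is adjusted to include zero):

    W_q = \Delta \cdot \mathrm{clip}\left(\left\lfloor \frac{W - q_{\min}}{\Delta} \right\rfloor + h(V),\ 0,\ 2^{n} - 1\right) + q_{\min}

with h(V) the rectified-sigmoid soft target from get_soft_targets (hardened to 0/1 when training is False); ste_floor and ste_clip are straight-through operations, so gradients reach the auxiliary variable.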