mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -0,0 +1,114 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List
16
+
17
+ import torch
18
+ import numpy as np
19
+ from torch import nn
20
+
21
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
22
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
23
+ from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
24
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
25
+
26
+
27
class LinearTempDecay:
    """
    Linear annealing schedule for the soft quantizer regularization temperature term.

    The temperature is held at ``start_b`` until step ``rel_start_decay * t_max``, then decreases
    linearly to ``end_b`` at step ``t_max`` and stays at ``end_b`` afterwards.
    """

    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
        """
        Initializes a LinearTempDecay object.

        Args:
            t_max: maximal time step (the step at which the temperature reaches ``end_b``).
            rel_start_decay: Fraction of ``t_max`` after which the decay starts.
            start_b: Starting value of the regularization term.
            end_b: Target value of the regularization term.
        """

        self.t_max = t_max
        self.start_decay = rel_start_decay * t_max
        self.start_b = start_b
        self.end_b = end_b

    def __call__(self, t: float) -> float:
        """
        Linear annealing scheduler for soft quantizer regularization temperature term.

        The value is computed with plain Python arithmetic. A float exponent works with ``torch.pow``
        on any device, so no per-step tensor allocation or device transfer is needed.

        Args:
            t: The current time step.

        Returns: Scheduled temperature.
        """

        if t < self.start_decay:
            return float(self.start_b)

        decay_length = self.t_max - self.start_decay
        if decay_length <= 0:
            # Empty (or inverted) decay window, e.g. t_max == 0 or rel_start_decay >= 1:
            # jump straight to the target value instead of dividing by zero.
            return float(self.end_b)

        rel_t = (t - self.start_decay) / decay_length
        # Clamp at 0 so the temperature never goes below end_b once t passes t_max.
        return self.end_b + (self.start_b - self.end_b) * max(0.0, 1.0 - rel_t)
66
+
67
+
68
class SoftQuantizerRegularization:
    """
    A class to handle the computation of soft quantizer regularization for GPTQ training.
    """

    def __init__(self, total_gradient_steps: int):
        """
        Initializes the regularization computation object with a LinearTempDecay object.

        Args:
            total_gradient_steps: The number of gradient steps during optimization.
        """

        # Initializing the temperature decay according to the number of expected gradient steps
        self.linear_decay = LinearTempDecay(total_gradient_steps)

        # Number of times this regularization has been evaluated; used as the decay time step.
        self.count_iter = 0

    def __call__(self, model: nn.Module, entropy_reg: float):
        """
        Returns the soft quantizer regularization value for SoftRounding.

        For every PytorchQuantizationWrapper in ``model``, the kernel quantizer's soft targets
        ``st`` contribute ``sum(1 - |2 * (st - 0.5)| ** b)``. Each element's term is 0 when its
        soft target is exactly 0 or 1, and 1 when it is 0.5. This pushes soft targets toward a hard
        rounding decision as the temperature ``b`` decays. Each call advances the decay step by one.

        Args:
            model: A model to be quantized with SoftRounding.
            entropy_reg: Entropy value to scale the quantizer regularization.

        Returns: Regularization value (``entropy_reg`` times the summed per-layer terms).
        """

        b = self.linear_decay(self.count_iter)

        soft_reg_aux: List[torch.Tensor] = []
        for layer in model.modules():
            if isinstance(layer, PytorchQuantizationWrapper):
                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
                                                                      fw_info=DEFAULT_PYTORCH_INFO)

                st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
                soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())

        # sum() starts from 0, so a model without wrapped layers yields 0, as before.
        reg = sum(soft_reg_aux)

        self.count_iter += 1

        return entropy_reg * reg
@@ -14,24 +14,25 @@
14
14
  # ==============================================================================
15
15
  import torch
16
16
  import torch.nn as nn
17
- from typing import List, Dict
17
+ from typing import Dict
18
18
  import numpy as np
19
19
 
20
- from model_compression_toolkit.core.common import Logger, max_power_of_two
20
+ from model_compression_toolkit.core.common import max_power_of_two
21
21
  from model_compression_toolkit import quantizers_infrastructure as qi
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
23
  from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
24
  from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
25
25
  BasePytorchGPTQTrainableQuantizer
26
26
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
27
27
  from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
28
- from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, N_EPOCHS, \
29
- MAX_ITERATIONS_DEFAULT, SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, SOFT_ROUNDING_BETA, GPTQ_ITER, AUXVAR
30
- from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
28
+ from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
29
+ SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
30
+ from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
31
31
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
32
32
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
33
33
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
34
34
  get_threshold_reshape_shape
35
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
35
36
 
36
37
 
37
38
  def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
@@ -67,46 +68,6 @@ def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
67
68
  max_val=int_threshold - 1)
68
69
 
69
70
 
70
- class LinearTempDecay:
71
- """
72
- Annealing process for the soft quantizer regularization temperature term.
73
- """
74
-
75
- def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
76
- """
77
- Initializes a LinearTempDecay object.
78
-
79
- Args:
80
- t_max: maximal time step.
81
- rel_start_decay: Decay step size at the beginning of the process.
82
- start_b: Starting value of the regularization term.
83
- end_b: Target value of the regularization term.
84
- """
85
-
86
- self.t_max = t_max
87
- self.start_decay = rel_start_decay * t_max
88
- self.start_b = start_b
89
- self.end_b = end_b
90
-
91
- def __call__(self, t: nn.Parameter) -> float:
92
- """
93
- Cosine annealing scheduler for soft quantizer regularization temperature term.
94
-
95
- Args:
96
- t: The current time step.
97
-
98
- Returns: Scheduled temperature.
99
- """
100
-
101
- is_before_start_decay = (t < self.start_decay).to(torch.float32)
102
-
103
- rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
104
-
105
- return self.start_b * is_before_start_decay + \
106
- (1 - is_before_start_decay) * \
107
- (self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])), (1 - rel_t)))
108
-
109
-
110
71
  @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
111
72
  quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
112
73
  quantizer_type=RoundingType.SoftQuantizer)
@@ -117,22 +78,15 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
117
78
 
118
79
  def __init__(self,
119
80
  quantization_config: TrainableQuantizerWeightsConfig,
120
- n_batches: int = None,
121
- quantization_parameter_learning: bool = False,
122
- n_epochs: int = N_EPOCHS):
81
+ quantization_parameter_learning: bool = False):
123
82
  """
124
83
  Construct a Pytorch model that utilize a fake weight quantizer of soft-quantizer for symmetric quantizer.
125
84
 
126
85
  Args:
127
86
  quantization_config: Trainable weights quantizer config.
128
- n_batches (int): number of batches in representative dataset
129
87
  quantization_parameter_learning (Bool): Whether to learn the threshold or not
130
- n_epochs (int): number of epochs the representative dataset is run during fine-tuning
131
88
  """
132
89
 
133
- if n_batches is None:
134
- Logger.error("SymmetricSoftRoundingGPTQ got an uninitialized n_batches argument.")
135
-
136
90
  super().__init__(quantization_config)
137
91
  self.num_bits = quantization_config.weights_n_bits
138
92
  self.per_channel = quantization_config.weights_per_channel_threshold
@@ -147,35 +101,24 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
147
101
  self.quantization_parameter_learning = quantization_parameter_learning
148
102
 
149
103
  # gamma and zeta are stretch parameters for computing the rectified sigmoind function.
150
- # beta is used to set the regularization term.
151
104
  # See: https://arxiv.org/pdf/2004.10568.pdf
152
105
  self.gamma = SOFT_ROUNDING_GAMMA
153
106
  self.zeta = SOFT_ROUNDING_ZETA
154
- self.beta = SOFT_ROUNDING_BETA
155
107
 
156
108
  self.quantizer_parameters = {}
157
109
 
158
- # Initializing the temperature decay according to the number of expected gradient steps
159
- num_iterations = MAX_ITERATIONS_DEFAULT if n_batches is None else n_epochs * n_batches
160
- self.linear_decay = LinearTempDecay(num_iterations)
161
-
162
110
  def initialize_quantization(self,
163
111
  tensor_shape: torch.Size,
164
112
  name: str,
165
- layer: qi.PytorchQuantizationWrapper) -> Dict[str, nn.Parameter]:
113
+ layer: qi.PytorchQuantizationWrapper):
166
114
  """
167
- Return a dictionary of quantizer parameters and their names.
115
+ Add quantizer parameters to the quantizer parameters dictionary
168
116
 
169
117
  Args:
170
118
  tensor_shape: tensor shape of the quantized tensor.
171
119
  name: Tensor name.
172
120
  layer: Layer to quantize.
173
-
174
- Returns:
175
- Dictionary of parameters names to the variables.
176
121
  """
177
- layer.register_parameter(f"{name}_{GPTQ_ITER}",
178
- nn.Parameter(to_torch_tensor(np.array([0])), requires_grad=False))
179
122
 
180
123
  if self.per_channel:
181
124
  threshold_tensor = to_torch_tensor(self.threshold_values)
@@ -195,31 +138,18 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
195
138
  layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
196
139
 
197
140
  # save the quantizer added parameters for later calculations
198
- self.quantizer_parameters = {PTQ_THRESHOLD: layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"),
199
- AUXVAR: layer.get_parameter(f"{name}_{AUXVAR}"),
200
- GPTQ_ITER: layer.get_parameter(f"{name}_{GPTQ_ITER}")}
141
+ self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
142
+ self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
201
143
 
202
144
  if self.quantization_parameter_learning:
203
- layer.register_parameter(f"{name}_{SCALE_PTQ}",
204
- nn.Parameter(torch.ones_like(torch.Tensor(self.threshold_values)),
205
- requires_grad=True))
206
-
207
- self.quantizer_parameters.update({SCALE_PTQ: layer.get_parameter(f"{name}_{SCALE_PTQ}")})
208
-
209
- return self.quantizer_parameters
210
-
211
- def get_regularization(self) -> torch.Tensor:
212
- """
213
- Computes the regularization term for the soft rounding loss.
214
-
215
- Returns:
216
- regularization term.
217
- """
218
-
219
- st = self.get_soft_targets()
220
- ar_iter = self.quantizer_parameters[GPTQ_ITER]
221
- b = self.linear_decay(ar_iter)
222
- return (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
145
+ if self.per_channel:
146
+ layer.register_parameter(f"{name}_{SCALE_PTQ}",
147
+ nn.Parameter(to_torch_tensor(torch.ones_like(torch.Tensor(self.threshold_values))),
148
+ requires_grad=True))
149
+ else:
150
+ layer.register_parameter(f"{name}_{SCALE_PTQ}",
151
+ nn.Parameter(to_torch_tensor((torch.tensor([1.0], requires_grad=True)))))
152
+ self.add_quantizer_variable(SCALE_PTQ, layer.get_parameter(f"{name}_{SCALE_PTQ}"), VariableGroup.QPARAMS)
223
153
 
224
154
  def get_soft_targets(self) -> torch.Tensor:
225
155
  """
@@ -229,28 +159,9 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
229
159
  A tensor with the soft rounding targets values.
230
160
 
231
161
  """
232
- scaled_sigmoid = torch.sigmoid(self.quantizer_parameters[AUXVAR]) * (self.zeta - self.gamma) + self.gamma
162
+ scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
233
163
  return torch.clip(scaled_sigmoid, min=0, max=1)
234
164
 
235
- def get_aux_variable(self) -> List[torch.Tensor]:
236
- """
237
- This function return a list with the quantizer's quantization auxiliary variables.
238
-
239
- Returns: A list with the quantization auxiliary variables.
240
- """
241
- return [self.quantizer_parameters.get(AUXVAR)]
242
-
243
- def get_quantization_variable(self) -> List[torch.Tensor]:
244
- """
245
- This function return a list with the quantizer's quantization parameters variables.
246
-
247
- Returns: A list with the quantization parameters.
248
- """
249
- if self.quantization_parameter_learning and not self.power_of_two:
250
- return [self.quantizer_parameters[SCALE_PTQ]]
251
- else:
252
- return []
253
-
254
165
  def get_quant_config(self) -> Dict[str, np.ndarray]:
255
166
  """
256
167
  Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
@@ -260,12 +171,13 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
260
171
  Keys must match NodeQuantizationConfig attributes
261
172
 
262
173
  """
263
- old_threshold = torch_tensor_to_numpy(self.quantizer_parameters[PTQ_THRESHOLD])
174
+ old_threshold = torch_tensor_to_numpy(self.get_quantizer_variable(PTQ_THRESHOLD))
175
+ old_threshold = np.resize(old_threshold, self.threshold_shape)
264
176
  if self.power_of_two:
265
177
  old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
266
178
  else:
267
179
  if self.quantization_parameter_learning:
268
- scale = torch.reshape(self.quantizer_parameters[SCALE_PTQ], self.threshold_shape)
180
+ scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
269
181
  old_threshold = old_threshold * torch_tensor_to_numpy(scale)
270
182
  old_threshold = old_threshold.reshape(self.threshold_shape)
271
183
  return {THRESHOLD: old_threshold}
@@ -283,17 +195,14 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
283
195
  Returns:
284
196
  quantized tensor
285
197
  """
286
- ar_iter = self.quantizer_parameters[GPTQ_ITER]
287
- auxvar = self.quantizer_parameters[AUXVAR]
288
- ptq_threshold_tensor = self.quantizer_parameters[PTQ_THRESHOLD]
198
+ auxvar = self.get_quantizer_variable(AUXVAR)
199
+ ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
289
200
 
290
201
  #####################################################
291
202
  # Soft Rounding
292
203
  #####################################################
293
204
  aux_var = self.get_soft_targets()
294
- if training:
295
- ar_iter.set_(ar_iter + 1)
296
- else:
205
+ if not training:
297
206
  aux_var = (aux_var >= 0.5).to(auxvar.dtype)
298
207
 
299
208
  if self.per_channel:
@@ -317,7 +226,7 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
317
226
  power_of_two=self.power_of_two)
318
227
 
319
228
  if self.quantization_parameter_learning and not self.power_of_two:
320
- scale = torch.reshape(self.quantizer_parameters[SCALE_PTQ], reshape_shape)
229
+ scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
321
230
  q_tensor *= scale
322
231
 
323
232
  else:
@@ -328,4 +237,8 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
328
237
  signed=True,
329
238
  power_of_two=self.power_of_two)
330
239
 
240
+ if self.quantization_parameter_learning and not self.power_of_two:
241
+ scale = self.get_quantizer_variable(SCALE_PTQ)
242
+ q_tensor *= scale
243
+
331
244
  return q_tensor
@@ -0,0 +1,194 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Dict
18
+ import numpy as np
19
+
20
+ from model_compression_toolkit import quantizers_infrastructure as qi
21
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
+ from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
25
+ BasePytorchGPTQTrainableQuantizer
26
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
27
+ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
28
+ from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
29
+ from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
30
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
31
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
32
+ mark_quantizer
33
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
34
+ VariableGroup
35
+ from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
36
+
37
+
38
def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
                                    auxvar_tensor: torch.Tensor,
                                    min_range: torch.Tensor,
                                    max_range: torch.Tensor,
                                    num_bits: int) -> torch.Tensor:
    """
    Fake-quantize a tensor on a uniform grid with a learnable rounding offset (GPTQ soft rounding).

    The input is mapped to integer grid indices by flooring. ``auxvar_tensor`` is added as the
    rounding offset, the result is clipped to ``[0, 2 ** num_bits - 1]``, and it is mapped back to
    the input's value range.

    Args:
        input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
        auxvar_tensor: Rounding offset added to the floored grid index (learned during gptq).
        min_range: Tensor with min values to compute the delta grid.
        max_range: Tensor with max values to compute the delta grid.
        num_bits: Num of bits to use.

    Returns:
        A quantized tensor.
    """
    # Shift the range so that zero lies exactly on the quantization grid.
    grid_min, grid_max = fix_range_to_include_zero(min_range, max_range, num_bits)
    step = qutils.calculate_delta_uniform(grid_min, grid_max, num_bits)
    top_index = 2 ** num_bits - 1

    # Floored grid index (ste_floor is presumably a straight-through-estimator floor) plus offset.
    grid_index = qutils.ste_floor((input_tensor - grid_min) / step) + auxvar_tensor
    clipped_index = qutils.ste_clip(grid_index, min_val=0, max_val=top_index)

    return step * clipped_index + grid_min
64
+
65
+
66
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=RoundingType.SoftQuantizer)
class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
    """
    Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
    """

    def __init__(self,
                 quantization_config: TrainableQuantizerWeightsConfig,
                 quantization_parameter_learning: bool = False):
        """
        Construct a Pytorch model that utilize a fake weight quantizer of soft-quantizer for uniform quantizer.

        Args:
            quantization_config: Trainable weights quantizer config.
            quantization_parameter_learning (Bool): Whether to learn the min/max ranges or not
        """

        super().__init__(quantization_config)
        self.num_bits = quantization_config.weights_n_bits
        self.per_channel = quantization_config.weights_per_channel_threshold

        self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
        self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]

        self.quantization_axis = quantization_config.weights_channels_axis
        self.quantization_parameter_learning = quantization_parameter_learning

        # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
        # See: https://arxiv.org/pdf/2004.10568.pdf
        self.gamma = SOFT_ROUNDING_GAMMA
        self.zeta = SOFT_ROUNDING_ZETA

    def initialize_quantization(self,
                                tensor_shape: torch.Size,
                                name: str,
                                layer: qi.PytorchQuantizationWrapper):
        """
        Add quantizer parameters to the quantizer parameters dictionary

        Registers ``{name}_{FQ_MIN}``, ``{name}_{FQ_MAX}`` and ``{name}_{AUXVAR}`` parameters on
        ``layer``. The min/max parameters are trainable only when quantization_parameter_learning
        is set; the auxiliary rounding variable is always trainable.

        Args:
            tensor_shape: tensor shape of the quantized tensor.
            name: Tensor name.
            layer: Layer to quantize.
        """

        # Add min and max variables to layer.
        if self.per_channel:
            min_values = to_torch_tensor(self.min_values)
            max_values = to_torch_tensor(self.max_values)
        else:
            min_values = torch.tensor(self.min_values)
            max_values = torch.tensor(self.max_values)

        layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(min_values, requires_grad=self.quantization_parameter_learning))
        layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(max_values, requires_grad=self.quantization_parameter_learning))

        # NOTE(review): the kernel is read from `layer.layer.weight` regardless of `name`.
        w = layer.layer.weight

        # Compute the initial rounding residual on the same zero-including grid that __call__
        # quantizes on (soft_rounding_unifrom_quantizer applies fix_range_to_include_zero).
        # Otherwise the initial soft targets would not match each weight's fractional position
        # on the grid that is actually used.
        grid_min, grid_max = fix_range_to_include_zero(min_values, max_values, self.num_bits)
        delta = qutils.calculate_delta_uniform(grid_min, grid_max, self.num_bits)
        w_clipped_normed = torch.clip((w - grid_min) / delta, 0, 2 ** self.num_bits - 1)
        rest = w_clipped_normed - torch.floor(w_clipped_normed)  # rest of rounding [0, 1)
        # Invert the stretched sigmoid: sigmoid(alpha) * (zeta - gamma) + gamma == rest
        alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)
        layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))

        # Save the quantizer parameters
        self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
        self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
        self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)

    def get_soft_targets(self) -> torch.Tensor:
        """
        Computes the rectified sigmoid function for the quantization target parameters.

        Returns:
            A tensor with the soft rounding targets values, clipped to [0, 1].

        """
        scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
        return torch.clip(scaled_sigmoid, min=0, max=1)

    def get_quant_config(self) -> Dict[str, np.ndarray]:
        """
        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining

        Returns:
            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
            Keys must match NodeQuantizationConfig attributes

        """
        min_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MIN))
        max_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MAX))
        return {RANGE_MIN: min_values,
                RANGE_MAX: max_values}

    def __call__(self,
                 inputs: nn.Parameter,
                 training: bool) -> torch.Tensor:
        """
        Quantize a tensor.

        In training mode the continuous soft targets are used as the rounding offset. Otherwise
        they are hardened to 0/1 by thresholding at 0.5.

        Args:
            inputs: Input tensor to quantize.
            training: whether in training mode or not

        Returns:
            quantized tensor
        """
        auxvar = self.get_quantizer_variable(AUXVAR)
        min_range = self.get_quantizer_variable(FQ_MIN)
        max_range = self.get_quantizer_variable(FQ_MAX)

        #####################################################
        # Soft Rounding
        #####################################################
        aux_var = self.get_soft_targets()
        if not training:
            aux_var = (aux_var >= 0.5).to(auxvar.dtype)

        #####################################################
        # Quantized Input
        #####################################################
        q_tensor = soft_rounding_unifrom_quantizer(input_tensor=inputs,
                                                   auxvar_tensor=aux_var,
                                                   min_range=min_range,
                                                   max_range=max_range,
                                                   num_bits=self.num_bits)

        return q_tensor
@@ -14,23 +14,23 @@
14
14
  # ==============================================================================
15
15
  import torch
16
16
  import torch.nn as nn
17
- from typing import List, Dict
17
+ from typing import Dict
18
18
  import numpy as np
19
19
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
20
20
 
21
21
  from model_compression_toolkit import quantizers_infrastructure as qi
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
23
  from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
24
  from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
25
25
  BasePytorchGPTQTrainableQuantizer
26
26
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
27
27
  from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
28
28
  from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
29
- from model_compression_toolkit.core.common.constants import THRESHOLD
29
+ from model_compression_toolkit.constants import THRESHOLD
30
30
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
31
31
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
32
32
  mark_quantizer
33
-
33
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
34
34
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
35
35
  get_threshold_reshape_shape
36
36
 
@@ -104,23 +104,19 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
104
104
  self.quantization_axis = quantization_config.weights_channels_axis
105
105
  self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
106
106
  self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
107
- self.quantizer_parameters = {}
108
107
 
109
108
 
110
109
  def initialize_quantization(self,
111
110
  tensor_shape: torch.Size,
112
111
  name: str,
113
- layer: qi.PytorchQuantizationWrapper) -> Dict[str, nn.Parameter]:
112
+ layer: qi.PytorchQuantizationWrapper):
114
113
  """
115
- Return a dictionary of quantizer parameters and their names.
114
+ Add quantizer parameters to the quantizer parameters dictionary
116
115
 
117
116
  Args:
118
117
  tensor_shape: tensor shape of the quantized tensor.
119
118
  name: Tensor name.
120
119
  layer: Layer to quantize.
121
-
122
- Returns:
123
- Dictionary of parameters names to the variables.
124
120
  """
125
121
 
126
122
  layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
@@ -131,27 +127,9 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
131
127
  requires_grad=True))
132
128
 
133
129
  # save the quantizer added parameters for later calculations
134
- self.quantizer_parameters = {PTQ_THRESHOLD: layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"),
135
- AUXVAR: layer.get_parameter(f"{name}_{AUXVAR}")}
136
-
137
- return self.quantizer_parameters
138
-
139
-
140
- def get_aux_variable(self) -> List[torch.Tensor]:
141
- """
142
- This function return a list with the quantizer's quantization auxiliary variables.
143
-
144
- Returns: A list with the quantization auxiliary variables.
145
- """
146
- return [self.quantizer_parameters.get(AUXVAR)]
130
+ self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
131
+ self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
147
132
 
148
- def get_quantization_variable(self) -> List[torch.Tensor]:
149
- """
150
- This function return a list with the quantizer's quantization parameters variables.
151
-
152
- Returns: A list with the quantization parameters.
153
- """
154
- return [self.quantizer_parameters.get(PTQ_THRESHOLD)]
155
133
 
156
134
  def get_quant_config(self) -> Dict[str, np.ndarray]:
157
135
  """
@@ -162,7 +140,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
162
140
  Keys must match NodeQuantizationConfig attributes
163
141
 
164
142
  """
165
- old_threshold = self.quantizer_parameters[PTQ_THRESHOLD]
143
+ old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
166
144
  return {THRESHOLD: torch_tensor_to_numpy(old_threshold).reshape(self.threshold_shape)}
167
145
 
168
146
  def __call__(self,
@@ -178,8 +156,8 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
178
156
  Returns:
179
157
  quantized tensor
180
158
  """
181
- auxvar = self.quantizer_parameters[AUXVAR]
182
- ptq_threshold_tensor = self.quantizer_parameters[PTQ_THRESHOLD]
159
+ auxvar = self.get_quantizer_variable(AUXVAR)
160
+ ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
183
161
 
184
162
  if self.per_channel:
185
163
  reshape_shape = get_threshold_reshape_shape(inputs.shape,