mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -11,4 +11,8 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- # ==============================================================================
14
+ # ==============================================================================
15
+
16
+ import model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.symmetric_ste
17
+ import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.symmetric_soft_quantizer
18
+ import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.uniform_soft_quantizer
@@ -0,0 +1,92 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from abc import abstractmethod
16
+ from typing import Union, Dict, List
17
+
18
+ from model_compression_toolkit.core.common.logger import Logger
19
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
20
+ from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
21
+
22
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
23
+ TrainableQuantizerActivationConfig
24
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
25
+ BaseTrainableQuantizer
26
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
27
+ BasePytorchTrainableQuantizer
28
+
29
+ if FOUND_TORCH:
30
+ from torch import Tensor
31
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
32
+
33
+ class BasePytorchGPTQTrainableQuantizer(BasePytorchTrainableQuantizer):
34
+ """
35
+ A base class for trainable Pytorch quantizer for GPTQ.
36
+ """
37
+
38
+ def __init__(self,
39
+ quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
40
+ """
41
+ Initializes BasePytorchGPTQTrainableQuantizer object.
42
+
43
+ Args:
44
+ quantization_config: quantizer config class contains all the information about a quantizer configuration.
45
+ """
46
+
47
+ super().__init__(quantization_config)
48
+
49
+ def update_layer_quantization_params(self, layer: PytorchQuantizationWrapper
50
+ ) -> (Dict[str, Tensor], Dict[str, Dict], Dict):
51
+ """
52
+ A Function to calculate the needed change in attributes in NodeQuantizationConfig after retraining.
53
+
54
+ Args:
55
+ layer: A wrapped Pytorch layer.
56
+
57
+ Returns:
58
+ 3 dictionaries describing the change in layer's weights, weights config, activation config
59
+ that changed during GPTQ retraining.
60
+ Keys must match NodeQuantizationConfig attributes
61
+
62
+ """
63
+ weights = {}
64
+ for weight, quantizer_vars, quantizer in layer.get_weights_vars():
65
+ if not isinstance(quantizer, BaseTrainableQuantizer):
66
+ Logger.error(f"Expecting a GPTQ trainable quantizer, " # pragma: no cover
67
+ f"but got {type(quantizer)} which is not callable.")
68
+ weights.update({weight: quantizer(training=False, inputs=quantizer_vars)})
69
+
70
+ quant_config = {WEIGHTS_QUANTIZATION_PARAMS: self.get_quant_config()}
71
+
72
+ return weights, quant_config, {}
73
+
74
+ @abstractmethod
75
+ def get_quant_config(self):
76
+ """
77
+ Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
78
+
79
+ Returns:
80
+ A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
81
+ Keys must match NodeQuantizationConfig attributes.
82
+
83
+ """
84
+ raise NotImplemented(f'{self.__class__.__name__} have to implement the ' # pragma: no cover
85
+ f'quantizer\'s get_quant_config.')
86
+
87
+ else:
88
+ class BasePytorchGPTQTrainableQuantizer: # pragma: no cover
89
+ def __init__(self, *args, **kwargs):
90
+ Logger.critical('Installing Pytorch is mandatory '
91
+ 'when using BasePytorchGPTQTrainableQuantizer. '
92
+ 'Could not find torch package.') # pragma: no cover
@@ -30,11 +30,20 @@ def calculate_delta(max_tensor: torch.Tensor,
30
30
  num_bits: int,
31
31
  signed: bool) -> torch.Tensor:
32
32
  """
33
- Compute the step size for the quantization.
33
+ Compute the step size for the symmetric quantization.
34
34
  """
35
35
  return max_tensor / (2 ** (num_bits - int(signed)))
36
36
 
37
37
 
38
+ def calculate_delta_uniform(min_tensor: torch.Tensor,
39
+ max_tensor: torch.Tensor,
40
+ num_bits: int) -> torch.Tensor:
41
+ """
42
+ Compute the step size for the uniform quantization.
43
+ """
44
+ return (max_tensor-min_tensor) / (2 ** num_bits - 1)
45
+
46
+
38
47
  def ste_ceil(x: torch.Tensor) -> torch.Tensor:
39
48
  """
40
49
  Return the ceil values of a tensor.
@@ -66,93 +75,6 @@ def ste_clip(x: torch.Tensor, min_val=-1.0, max_val=1.0) -> torch.Tensor:
66
75
  return (torch.clip(x, min=min_val, max=max_val) - x).detach() + x
67
76
 
68
77
 
69
- def gumbel_softmax(x: torch.Tensor, tau: Union[torch.Tensor,float], gumbel_tensor: Union[torch.Tensor,float], eps: float = 1e-6, axis=0,
70
- gumbel_scale: float = 1.0) -> torch.Tensor:
71
- """
72
- A gumbel softmax function.
73
- Args:
74
- x: A tensor of log probability.
75
- tau: A temperature tensor.
76
- gumbel_tensor: A tensor of gumbel random variable.
77
- eps: A small number for numeric stability.
78
- axis: A integer representing the axis of which the gumbel softmax applyed on.
79
- gumbel_scale: A normalization factor for the gumbel tensor values
80
-
81
- Returns: A gumbel softmax probability tensor.
82
-
83
- """
84
- return softmax((log_softmax(x, dim=axis) + gumbel_tensor * gumbel_scale) / (tau + eps), dim=axis)
85
-
86
-
87
- def select_gumbel(prob: torch.Tensor) -> torch.Tensor:
88
- """
89
- This function apply ste on the output of the gumbel softmax.
90
- Args:
91
- prob: A tensor of probability.
92
-
93
- Returns: A Tensor of ohe hot vector
94
-
95
- """
96
- max_index = torch.argmax(prob, dim=0)
97
- axis_list = [i for i in range(len(max_index.shape))]
98
- axis_list.insert(0, len(max_index.shape))
99
- one_hot_prob = torch.permute(one_hot(max_index, num_classes=prob.shape[0]), axis_list)
100
- return one_hot_prob + 0*prob
101
-
102
-
103
- def ste_gumbel(prob: torch.Tensor) -> torch.Tensor:
104
- """
105
- This function apply ste on the output of the gumbel softmax.
106
- Args:
107
- prob:A tensor of probability
108
-
109
- Returns: A Tensor of ohe hot vector with STE.
110
-
111
- """
112
- delta = (select_gumbel(prob) - prob).detach()
113
- return prob + delta
114
-
115
-
116
- def sample_gumbel(shape, eps=1e-6) -> torch.Tensor:
117
- """
118
- A function that sample a tensor of i.i.d gumbel random variable.
119
- Args:
120
- shape: The tensor output shape
121
- eps: A small number for numeric stability.
122
-
123
- Returns: A tensor of i.i.d gumbel random variable.
124
-
125
- """
126
- u = to_torch_tensor(torch.rand(shape))
127
- return -torch.log(-torch.log(u + eps) + eps)
128
-
129
-
130
- def symmetric_quantizer(input_tensor: torch.Tensor,
131
- max_tensor: torch.Tensor,
132
- num_bits: int,
133
- signed: bool,
134
- power_of_two: bool = False) -> torch.Tensor:
135
- """
136
- Quantize a tensor symmetrically.
137
- Args:
138
- input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
139
- max_tensor: Tensor with max values to compute the threshold.
140
- num_bits: Num of bits to use.
141
- signed: Signedness of the quantization range.
142
- power_of_two: Whether the threshold should be constrained or not.
143
- Returns:
144
- A quantized tensor.
145
- """
146
-
147
- if power_of_two:
148
- max_tensor = power_of_two_max(max_tensor)
149
- delta_tensor = calculate_delta(max_tensor, num_bits, signed)
150
- tensor_q = ste_round(input_tensor / delta_tensor)
151
- min_int = -int(signed) * (2 ** (num_bits - int(signed)))
152
- max_int = (2 ** (num_bits - int(signed))) - 1
153
- return delta_tensor * ste_clip(tensor_q, min_val=min_int, max_val=max_int)
154
-
155
-
156
78
  def fix_range_to_include_zero(range_min: torch.Tensor,
157
79
  range_max: torch.Tensor,
158
80
  n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -180,34 +102,3 @@ def fix_range_to_include_zero(range_min: torch.Tensor,
180
102
  min_range_adj = min_range_adj * mid_range + max_negative * range_min
181
103
  max_range_adj = max_range_adj * mid_range + min_positive * range_max
182
104
  return min_range_adj, max_range_adj
183
-
184
-
185
- def uniform_quantizer(tensor_data: torch.Tensor,
186
- range_min: torch.Tensor,
187
- range_max: torch.Tensor,
188
- n_bits: int) -> torch.Tensor:
189
- """
190
- Quantize a tensor according to given range (min, max) and number of bits.
191
- Args:
192
- tensor_data: Tensor values to quantize.
193
- range_min: minimum bound of the range for quantization (or array of min values per channel).
194
- range_max: maximum bound of the range for quantization (or array of max values per channel).
195
- n_bits: Number of bits to quantize the tensor.
196
- Returns:
197
- Quantized data.
198
- """
199
- # adjusts the quantization rage so the quantization grid include zero.
200
- a, b = fix_range_to_include_zero(range_min, range_max, n_bits)
201
-
202
- # Compute the step size of quantized values.
203
- delta_tensor = (b - a) / (2 ** n_bits - 1)
204
-
205
- # Apply rounding
206
- input_tensor_int = ste_round((tensor_data - a) / delta_tensor)
207
-
208
- # Clip data in range
209
- clipped_tensor = ste_clip(input_tensor_int, min_val=0, max_val=2 ** n_bits - 1)
210
-
211
- # Quantize the data between min/max of quantization range.
212
- q = delta_tensor * clipped_tensor + a
213
- return q
@@ -0,0 +1,75 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List, Dict, Tuple
16
+
17
+ from model_compression_toolkit.gptq import GradientPTQConfigV2
18
+ from model_compression_toolkit.core import common
19
+ from model_compression_toolkit.core.pytorch.constants import KERNEL
20
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
21
+ get_activation_inferable_quantizer_kwargs
22
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
23
+ BasePytorchGPTQTrainableQuantizer
24
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
25
+ get_inferable_quantizer_class
26
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
27
+ BasePyTorchInferableQuantizer
28
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
29
+ get_trainable_quantizer_weights_config
30
+ from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
31
+ from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
32
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
33
+ get_trainable_quantizer_class
34
+
35
+
36
+ def quantization_builder(n: common.BaseNode,
37
+ gptq_config: GradientPTQConfigV2,
38
+ ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer],
39
+ List[BasePyTorchInferableQuantizer]]:
40
+ """
41
+ Build quantizers for a node according to its quantization configuration and
42
+ a global NoOpQuantizeConfig object.
43
+
44
+ Args:
45
+ n: Node to build its QuantizeConfig.
46
+ gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
47
+
48
+ Returns:
49
+ A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
50
+ Note that we return a dictionary although there is only a single attribute that is being mapped to a quantizer,
51
+ to be compatible with the quantization infrastructure template.
52
+ """
53
+
54
+ weights_quantizers = {}
55
+ if n.is_weights_quantization_enabled():
56
+ quant_method = n.final_weights_quantization_cfg.weights_quantization_method
57
+ quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Weights,
58
+ quantizer_type=gptq_config.rounding_type,
59
+ quant_method=quant_method,
60
+ quantizer_base_class=BasePytorchGPTQTrainableQuantizer)
61
+ weights_quantizers.update({KERNEL: quantizer_class(get_trainable_quantizer_weights_config(n),
62
+ **gptq_config.gptq_quantizer_params_override)})
63
+ activation_quantizers = []
64
+ if n.is_activation_quantization_enabled():
65
+ quant_method = n.final_activation_quantization_cfg.activation_quantization_method
66
+
67
+ quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
68
+ quant_method=quant_method,
69
+ quantizer_base_class=BasePyTorchInferableQuantizer)
70
+
71
+ kwargs = get_activation_inferable_quantizer_kwargs(n)
72
+
73
+ activation_quantizers.append(quantizer_class(**kwargs))
74
+
75
+ return weights_quantizers, activation_quantizers
@@ -0,0 +1,45 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Callable
16
+
17
+ from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
18
+ from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
19
+ SoftQuantizerRegularization
20
+
21
+
22
+ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
23
+ """
24
+ Returns a function that computes the regularization term for GPTQ training based on the given
25
+ rounding type in the GPTQ configuration.
26
+
27
+ Args:
28
+ gptq_config: A GPTQ configuration.
29
+ representative_data_gen: Dataset used for the GPTQ training.
30
+
31
+ Returns: A function for computing the regularization. If there is no regularization function defined for the given
32
+ rounding type, then it returns a function that just returns 0.
33
+
34
+ """
35
+ if gptq_config.rounding_type == RoundingType.SoftQuantizer:
36
+ # dry run on the representative dataset to count number of batches
37
+ num_batches = 0
38
+ for _ in representative_data_gen():
39
+ num_batches += 1
40
+
41
+ n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
42
+ not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
43
+ return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
44
+ else:
45
+ return lambda m, e_reg: 0
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -0,0 +1,115 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List
16
+
17
+ import torch
18
+ import numpy as np
19
+ from torch import nn
20
+
21
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
22
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
23
+ from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
24
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
25
+
26
+
27
+ class LinearTempDecay:
28
+ """
29
+ Annealing process for the soft quantizer regularization temperature term.
30
+ """
31
+
32
+ def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
33
+ """
34
+ Initializes a LinearTempDecay object.
35
+
36
+ Args:
37
+ t_max: maximal time step.
38
+ rel_start_decay: Decay step size at the beginning of the process.
39
+ start_b: Starting value of the regularization term.
40
+ end_b: Target value of the regularization term.
41
+ """
42
+
43
+ self.t_max = t_max
44
+ self.start_decay = rel_start_decay * t_max
45
+ self.start_b = start_b
46
+ self.end_b = end_b
47
+
48
+ def __call__(self, t: float) -> float:
49
+ """
50
+ Cosine annealing scheduler for soft quantizer regularization temperature term.
51
+
52
+ Args:
53
+ t: The current time step.
54
+
55
+ Returns: Scheduled temperature.
56
+ """
57
+
58
+ is_before_start_decay = (t < self.start_decay)
59
+
60
+ rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
61
+
62
+ return self.start_b * is_before_start_decay + \
63
+ (1 - is_before_start_decay) * \
64
+ (self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])),
65
+ to_torch_tensor(np.array((1 - rel_t)))))
66
+
67
+
68
+ class SoftQuantizerRegularization:
69
+ """
70
+ A class to handle the computation of soft quantizer regularization for GPTQ training.
71
+ """
72
+
73
+ def __init__(self, total_gradient_steps: int):
74
+ """
75
+ Initializes the regularization computation object with a LinearDecay object.
76
+
77
+ Args:
78
+ total_gradient_steps: The number of gradient steps during optimization.
79
+ """
80
+
81
+ # Initializing the temperature decay according to the number of expected gradient steps
82
+ self.linear_decay = LinearTempDecay(total_gradient_steps)
83
+
84
+ self.count_iter = 0
85
+
86
+ def __call__(self, model: nn.Module, entropy_reg: float):
87
+ """
88
+ Returns the soft quantizer regularization value for SoftRounding.
89
+
90
+ Args:
91
+ model: A model to be quantized with SoftRounding.
92
+ entropy_reg: Entropy value to scale the quantizer regularization.
93
+
94
+ Returns: Regularization value.
95
+ """
96
+
97
+ soft_reg_aux: List[torch.Tensor] = []
98
+ for layer in model.modules():
99
+ if isinstance(layer, PytorchQuantizationWrapper):
100
+ kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
101
+ fw_info=DEFAULT_PYTORCH_INFO)
102
+
103
+ st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
104
+ b = self.linear_decay(self.count_iter)
105
+
106
+ soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
107
+
108
+ reg = 0
109
+
110
+ for sq in soft_reg_aux:
111
+ reg += sq
112
+
113
+ self.count_iter += 1
114
+
115
+ return entropy_reg * reg