mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -13,23 +13,24 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Dict, Any, List
16
+ from typing import Dict, Any
17
17
 
18
18
  import numpy as np
19
19
  import tensorflow as tf
20
20
 
21
- from model_compression_toolkit import RoundingType
21
+ from model_compression_toolkit.gptq import RoundingType
22
22
  from model_compression_toolkit import quantizers_infrastructure as qi
23
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
24
- from model_compression_toolkit.gptq.common.gptq_constants import GPTQ_ITER, AUXVAR, PTQ_THRESHOLD
23
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
24
+ from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
25
25
  from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
26
- from model_compression_toolkit.core.common.constants import THRESHOLD
26
+ from model_compression_toolkit.constants import THRESHOLD
27
27
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
28
28
  from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
29
29
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
30
30
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
31
31
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
32
32
  get_threshold_reshape_shape
33
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
33
34
 
34
35
 
35
36
  def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
@@ -96,30 +97,20 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
96
97
  self.quantization_axis = quantization_config.weights_channels_axis
97
98
  self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
98
99
  self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
99
- self.quantizer_parameters = {}
100
100
 
101
101
  def initialize_quantization(self,
102
102
  tensor_shape: Any,
103
103
  name: str,
104
- layer: Any) -> Dict[Any, Any]:
104
+ layer: Any):
105
105
  """
106
- Return a dictionary of quantizer parameters and their names.
106
+ Add quantizer parameters to the quantizer parameters dictionary
107
107
 
108
108
  Args:
109
109
  tensor_shape: tensor shape of the quantized tensor.
110
110
  name: Tensor name.
111
111
  layer: Layer to quantize.
112
-
113
- Returns:
114
- Dictionary of parameters names to the variables.
115
112
  """
116
113
 
117
- ar_iter = layer.add_weight(
118
- f"{name}_{GPTQ_ITER}",
119
- shape=(),
120
- initializer=tf.keras.initializers.Constant(0.0),
121
- trainable=False)
122
-
123
114
  ptq_threshold_tensor = layer.add_weight(
124
115
  f"{name}_{PTQ_THRESHOLD}",
125
116
  shape=len(self.threshold_values) if self.per_channel else (),
@@ -135,10 +126,8 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
135
126
  trainable=True)
136
127
 
137
128
  # save the quantizer added parameters for later calculations
138
- self.quantizer_parameters = {PTQ_THRESHOLD: ptq_threshold_tensor,
139
- AUXVAR: auxvar_tensor,
140
- GPTQ_ITER: ar_iter}
141
- return self.quantizer_parameters
129
+ self.add_quantizer_variable(PTQ_THRESHOLD, ptq_threshold_tensor, VariableGroup.QPARAMS)
130
+ self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
142
131
 
143
132
  def __call__(self,
144
133
  inputs: tf.Tensor,
@@ -154,8 +143,8 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
154
143
  The quantized tensor.
155
144
  """
156
145
 
157
- auxvar = self.quantizer_parameters[AUXVAR]
158
- ptq_threshold_tensor = self.quantizer_parameters[PTQ_THRESHOLD]
146
+ auxvar = self.get_quantizer_variable(AUXVAR)
147
+ ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
159
148
 
160
149
  if self.per_channel:
161
150
  reshape_shape = get_threshold_reshape_shape(inputs.shape,
@@ -178,25 +167,6 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
178
167
  signed=True,
179
168
  power_of_two=self.power_of_two)
180
169
 
181
- def get_aux_variable(self) -> List[tf.Tensor]:
182
- """
183
- This function return a list with the quantizer's quantization auxiliary variables.
184
-
185
- Returns: A list with the quantization auxiliary variables.
186
-
187
- """
188
-
189
- return [self.quantizer_parameters[AUXVAR]]
190
-
191
- def get_quantization_variable(self) -> List[tf.Tensor]:
192
- """
193
- This function return a list with the quantizer's quantization parameters variables.
194
-
195
- Returns: A list with the quantization parameters.
196
-
197
- """
198
-
199
- return [self.quantizer_parameters[PTQ_THRESHOLD]]
200
170
 
201
171
  def get_quant_config(self) -> Dict[str, np.ndarray]:
202
172
  """
@@ -207,5 +177,5 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
207
177
  Keys must match NodeQuantizationConfig attributes
208
178
 
209
179
  """
210
- old_threshold = self.quantizer_parameters[PTQ_THRESHOLD]
180
+ old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
211
181
  return {THRESHOLD: old_threshold.numpy().reshape(self.threshold_shape)}
@@ -0,0 +1,29 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Type
17
+
18
+ from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
19
+ from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
20
+ from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
21
+
22
+
23
+ class GPTQPytorchImplemantation(GPTQFrameworkImplemantation, PytorchImplementation):
24
+
25
+ def get_gptq_trainer_obj(self) -> Type[PytorchGPTQTrainer]:
26
+ """
27
+ Returns: Pytorch object of GPTQTrainer
28
+ """
29
+ return PytorchGPTQTrainer
@@ -19,21 +19,21 @@ from torch.nn import Module
19
19
  from tqdm import tqdm
20
20
  import copy
21
21
  import torch
22
- from model_compression_toolkit.core.common.logger import Logger
22
+ from model_compression_toolkit.logger import Logger
23
23
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
24
24
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
25
25
  from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
26
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2, RoundingType
26
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
27
27
  from model_compression_toolkit.core.common import Graph, BaseNode
28
28
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
29
29
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
30
30
  from model_compression_toolkit.core.pytorch.constants import BIAS
31
31
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model, torch_tensor_to_numpy
32
32
  from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
33
- get_weights_for_loss, get_soft_rounding_reg
33
+ get_weights_for_loss
34
34
  from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
35
- from model_compression_toolkit.gptq.common.gptq_constants import REGULARIZATION_VALUES
36
35
  from model_compression_toolkit import quantizers_infrastructure as qi
36
+ from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
37
37
  from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
38
38
 
39
39
 
@@ -63,7 +63,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
63
63
  fw_info: Framework information
64
64
  representative_data_gen: Dataset to use for inputs of the models.
65
65
  """
66
- super().__init__(graph_float, graph_quant, gptq_config, fw_impl, fw_info, representative_data_gen)
66
+ super().__init__(graph_float, graph_quant, gptq_config, fw_impl, fw_info)
67
67
  self.loss_list = []
68
68
  self.input_scale = 1
69
69
  if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
@@ -71,7 +71,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
71
71
  else:
72
72
  self.input_scale = self.gptq_user_info.input_scale
73
73
 
74
- trainable_weights, trainable_bias, trainable_threshold, trainable_temperature = get_gptq_trainable_parameters(
74
+ trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
75
75
  self.fxp_model,
76
76
  add_bias=self.gptq_config.train_bias)
77
77
 
@@ -86,7 +86,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
86
86
  trainable_bias,
87
87
  trainable_threshold)
88
88
 
89
- self.weights_for_average_loss = to_torch_tensor(self.compute_jacobian_based_weights(representative_data_gen))
89
+ self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights(representative_data_gen))
90
+
91
+ self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
90
92
 
91
93
  def _is_gptq_applicable(self,
92
94
  node: BaseNode) -> bool:
@@ -184,9 +186,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
184
186
  self.compare_points_std,
185
187
  self.weights_for_average_loss)
186
188
 
187
- reg_value = self.gptq_config.quantizer_config.get_regularization_value(
188
- self.fxp_model,
189
- **{REGULARIZATION_VALUES: self._get_quantizer_regularization_values(self.gptq_config.rounding_type)})
189
+ reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
190
190
 
191
191
  loss_value += reg_value
192
192
 
@@ -272,18 +272,3 @@ class PytorchGPTQTrainer(GPTQTrainer):
272
272
  if hasattr(layer.layer, BIAS):
273
273
  bias = getattr(layer.layer, BIAS)
274
274
  bias.requires_grad = self.gptq_config.train_bias
275
-
276
- def _get_quantizer_regularization_values(self, rounding_type: RoundingType) -> List[torch.Tensor]:
277
- """
278
- Mapping between a rounding type to its matching regularization method.
279
-
280
- Args:
281
- rounding_type: GPTQ rounding type.
282
-
283
- Returns: A regularization computation method.
284
-
285
- """
286
- if rounding_type == RoundingType.SoftQuantizer:
287
- return get_soft_rounding_reg(self.fxp_model)
288
- else:
289
- return []
@@ -15,11 +15,11 @@
15
15
  import torch
16
16
  import torch.nn as nn
17
17
  from typing import List
18
-
19
18
  from model_compression_toolkit.core.pytorch.constants import BIAS
20
19
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
21
20
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
22
21
  from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
22
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
23
23
 
24
24
 
25
25
  def get_gptq_trainable_parameters(fxp_model: nn.Module,
@@ -39,21 +39,23 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
39
39
  trainable_aux_weights = nn.ParameterList()
40
40
  trainable_threshold = nn.ParameterList()
41
41
  trainable_bias = nn.ParameterList()
42
- trainable_temperature = nn.ParameterList()
43
42
 
44
43
  for layer in fxp_model.modules():
45
44
  if isinstance(layer, PytorchQuantizationWrapper):
46
45
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
47
46
  fw_info=DEFAULT_PYTORCH_INFO)
48
47
 
49
- trainable_aux_weights.extend(layer.weights_quantizers[kernel_attribute].get_aux_variable())
50
- trainable_threshold.extend(layer.weights_quantizers[kernel_attribute].get_quantization_variable())
48
+ # collect trainable weights per quantizer
49
+ quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
50
+ quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
51
+ trainable_aux_weights.extend(quantizer_trainable_weights)
52
+ trainable_threshold.extend(quantizer_trainable_threshold)
51
53
 
52
54
  if add_bias and hasattr(layer.layer, BIAS):
53
55
  bias = getattr(layer.layer, BIAS)
54
56
  trainable_bias.append(bias)
55
57
 
56
- return trainable_aux_weights, trainable_bias, trainable_threshold, trainable_temperature
58
+ return trainable_aux_weights, trainable_bias, trainable_threshold
57
59
 
58
60
 
59
61
  def get_weights_for_loss(fxp_model: nn.Module) -> [List[nn.Parameter], List[torch.Tensor]]:
@@ -77,25 +79,3 @@ def get_weights_for_loss(fxp_model: nn.Module) -> [List[nn.Parameter], List[torc
77
79
  fxp_weights_list.append(quantizer(training=False, inputs=quantizer_vars))
78
80
 
79
81
  return flp_weights_list, fxp_weights_list
80
-
81
-
82
- # TODO: this function need to move to location that is relevant only for soft quantizer -
83
- # once deciding how to handle GPTQ quantizers regularization.
84
- def get_soft_rounding_reg(fxp_model: nn.Module) -> List[torch.Tensor]:
85
- """
86
- This function returns the soft quantizer regularization values for SoftRounding.
87
-
88
- Args:
89
- fxp_model: A model to be quantized with SoftRounding.
90
-
91
- Returns: A list of tensors.
92
- """
93
-
94
- soft_reg_aux: List[torch.Tensor] = []
95
- for layer in fxp_model.modules():
96
- if isinstance(layer, PytorchQuantizationWrapper):
97
- kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
98
- fw_info=DEFAULT_PYTORCH_INFO)
99
-
100
- soft_reg_aux.append(layer.weights_quantizers[kernel_attribute].get_regularization())
101
- return soft_reg_aux
@@ -14,17 +14,18 @@
14
14
  # ==============================================================================
15
15
  from typing import Callable
16
16
  from model_compression_toolkit.core import common
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import PYTORCH
20
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2, RoundingType
21
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import PYTORCH
20
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
22
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
23
23
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
24
+ from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
24
25
  from model_compression_toolkit.gptq.runner import gptq_runner
25
26
  from model_compression_toolkit.core.exporter import export_model
26
27
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
27
- from model_compression_toolkit import CoreConfig, GPTQQuantizerConfig
28
+ from model_compression_toolkit.core import CoreConfig
28
29
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
29
30
  MixedPrecisionQuantizationConfigV2
30
31
 
@@ -35,8 +36,8 @@ LR_QUANTIZATION_PARAM_DEFAULT = 1e-4
35
36
 
36
37
  if FOUND_TORCH:
37
38
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
38
- from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
39
- from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
39
+ from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
40
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
40
41
  from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
41
42
  from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
42
43
  import torch
@@ -71,33 +72,19 @@ if FOUND_TORCH:
71
72
  Import MCT and Create a GradientPTQConfigV2 to run for 5 epochs:
72
73
 
73
74
  >>> import model_compression_toolkit as mct
74
- >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=5)
75
+ >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
75
76
 
76
77
  Other PyTorch optimizers can be passed with dummy params:
77
78
 
78
79
  >>> import torch
79
- >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
80
+ >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
80
81
 
81
82
  The configuration can be passed to :func:`~model_compression_toolkit.pytorch_post_training_quantization` in order to quantize a pytorch model using gptq.
82
83
 
83
84
  """
84
- bias_optimizer = Adam([torch.Tensor([])], lr=LR_BIAS_DEFAULT)
85
- optimizer_quantization_parameter = Adam([torch.Tensor([])], lr=LR_QUANTIZATION_PARAM_DEFAULT)
86
- # TODO: Once implementing Soft Quantizer for GPTQ in Pytorch:
87
- # - change default quantization_parameters_learning to True.
88
- # - remove explicit rounding_type and quantizer_config (and let it use the default GradientPTQConfig).
89
- return GradientPTQConfigV2(n_epochs,
90
- optimizer,
91
- optimizer_rest=optimizer_rest,
92
- loss=loss,
93
- log_function=log_function,
94
- train_bias=True,
95
- optimizer_quantization_parameter=optimizer_quantization_parameter,
96
- optimizer_bias=bias_optimizer,
97
- rounding_type=RoundingType.STE,
98
- quantizer_config=GPTQQuantizerConfig(),
99
- quantization_parameters_learning=False,
100
- )
85
+ bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
86
+ return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
87
+ log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)
101
88
 
102
89
 
103
90
  def pytorch_gradient_post_training_quantization_experimental(model: Module,
@@ -131,7 +118,7 @@ if FOUND_TORCH:
131
118
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
132
119
  gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
133
120
  gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
134
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. `Default PyTorch TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/pytorch_tp_models/pytorch_default.py>`_
121
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
135
122
  new_experimental_exporter (bool): Whether exporting the quantized model using new exporter or not (in progress. Avoiding it for now is recommended).
136
123
 
137
124
  Returns:
@@ -155,26 +142,26 @@ if FOUND_TORCH:
155
142
 
156
143
  Create MCT core configurations with number of calibration iterations set to 1:
157
144
 
158
- >>> config = mct.CoreConfig()
145
+ >>> config = mct.core.CoreConfig()
159
146
 
160
147
  Pass the module, the representative dataset generator and the configuration (optional) to get a quantized module
161
148
 
162
- >>> quantized_module, quantization_info = mct.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
149
+ >>> quantized_module, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
163
150
 
164
151
  """
165
152
 
166
153
  if core_config.mixed_precision_enable:
167
154
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
168
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
155
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
169
156
  "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
170
157
  "API, or pass a valid mixed precision configuration.") # pragma: no cover
171
158
 
172
- common.Logger.info("Using experimental mixed-precision quantization. "
159
+ Logger.info("Using experimental mixed-precision quantization. "
173
160
  "If you encounter an issue please file a bug.")
174
161
 
175
162
  tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
176
163
 
177
- fw_impl = PytorchImplementation()
164
+ fw_impl = GPTQPytorchImplemantation()
178
165
 
179
166
  # ---------------------- #
180
167
  # Core Runner
@@ -205,7 +192,7 @@ if FOUND_TORCH:
205
192
  Logger.warning('Using new experimental exported models. '
206
193
  'Please do not use unless you are familiar with what you are doing')
207
194
 
208
- return get_fully_quantized_pytorch_model(graph_gptq)
195
+ return get_exportable_pytorch_model(graph_gptq)
209
196
 
210
197
  return export_model(graph_gptq,
211
198
  DEFAULT_PYTORCH_INFO,
@@ -15,3 +15,4 @@
15
15
 
16
16
  import model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.symmetric_ste
17
17
  import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.symmetric_soft_quantizer
18
+ import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.uniform_soft_quantizer
@@ -13,10 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from abc import abstractmethod
16
- from typing import Union, Dict, List
16
+ from typing import Union, Dict
17
17
 
18
- from model_compression_toolkit.core.common.logger import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import FOUND_TORCH
20
20
  from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
21
21
 
22
22
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
@@ -71,26 +71,6 @@ if FOUND_TORCH:
71
71
 
72
72
  return weights, quant_config, {}
73
73
 
74
- def get_aux_variable(self) -> List[Tensor]:
75
- """
76
- This function return a list with the quantizer's quantization auxiliary variables.
77
-
78
- Returns: A list with the quantization auxiliary variables.
79
-
80
- """
81
-
82
- return [] # pragma: no cover
83
-
84
- def get_quantization_variable(self) -> List[Tensor]:
85
- """
86
- This function return a list with the quantizer's quantization parameters variables.
87
-
88
- Returns: A list with the quantization parameters.
89
-
90
- """
91
-
92
- return [] # pragma: no cover
93
-
94
74
  @abstractmethod
95
75
  def get_quant_config(self):
96
76
  """
@@ -14,9 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import Union, Tuple
16
16
  import torch
17
- from torch.nn.functional import softmax, log_softmax, one_hot
18
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD
19
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
17
+ from model_compression_toolkit.constants import MIN_THRESHOLD
20
18
 
21
19
 
22
20
  def power_of_two_max(max_tensor: torch.Tensor) -> torch.Tensor:
@@ -30,11 +28,20 @@ def calculate_delta(max_tensor: torch.Tensor,
30
28
  num_bits: int,
31
29
  signed: bool) -> torch.Tensor:
32
30
  """
33
- Compute the step size for the quantization.
31
+ Compute the step size for the symmetric quantization.
34
32
  """
35
33
  return max_tensor / (2 ** (num_bits - int(signed)))
36
34
 
37
35
 
36
+ def calculate_delta_uniform(min_tensor: torch.Tensor,
37
+ max_tensor: torch.Tensor,
38
+ num_bits: int) -> torch.Tensor:
39
+ """
40
+ Compute the step size for the uniform quantization.
41
+ """
42
+ return (max_tensor-min_tensor) / (2 ** num_bits - 1)
43
+
44
+
38
45
  def ste_ceil(x: torch.Tensor) -> torch.Tensor:
39
46
  """
40
47
  Return the ceil values of a tensor.
@@ -42,6 +49,13 @@ def ste_ceil(x: torch.Tensor) -> torch.Tensor:
42
49
  return (torch.ceil(x) - x).detach() + x
43
50
 
44
51
 
52
+ def ste_floor(x: torch.Tensor) -> torch.Tensor:
53
+ """
54
+ Return the floor values of a tensor.
55
+ """
56
+ return (torch.floor(x) - x).detach() + x
57
+
58
+
45
59
  def ste_round(x: torch.Tensor) -> torch.Tensor:
46
60
  """
47
61
  Calculate the rounded values of a tensor
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import List, Dict, Tuple
16
16
 
17
- from model_compression_toolkit import GradientPTQConfigV2
17
+ from model_compression_toolkit.gptq import GradientPTQConfigV2
18
18
  from model_compression_toolkit.core import common
19
19
  from model_compression_toolkit.core.pytorch.constants import KERNEL
20
20
  from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
@@ -59,7 +59,7 @@ def quantization_builder(n: common.BaseNode,
59
59
  quant_method=quant_method,
60
60
  quantizer_base_class=BasePytorchGPTQTrainableQuantizer)
61
61
  weights_quantizers.update({KERNEL: quantizer_class(get_trainable_quantizer_weights_config(n),
62
- **gptq_config.get_extended_quantizer_parametes())})
62
+ **gptq_config.gptq_quantizer_params_override)})
63
63
  activation_quantizers = []
64
64
  if n.is_activation_quantization_enabled():
65
65
  quant_method = n.final_activation_quantization_cfg.activation_quantization_method
@@ -0,0 +1,45 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Callable
16
+
17
+ from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
18
+ from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
19
+ SoftQuantizerRegularization
20
+
21
+
22
+ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
23
+ """
24
+ Returns a function that computes the regularization term for GPTQ training based on the given
25
+ rounding type in the GPTQ configuration.
26
+
27
+ Args:
28
+ gptq_config: A GPTQ configuration.
29
+ representative_data_gen: Dataset used for the GPTQ training.
30
+
31
+ Returns: A function for computing the regularization. If there is no regularization function defined for the given
32
+ rounding type, then it returns a function that just returns 0.
33
+
34
+ """
35
+ if gptq_config.rounding_type == RoundingType.SoftQuantizer:
36
+ # dry run on the representative dataset to count number of batches
37
+ num_batches = 0
38
+ for _ in representative_data_gen():
39
+ num_batches += 1
40
+
41
+ n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
42
+ not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
43
+ return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
44
+ else:
45
+ return lambda m, e_reg: 0