mct-nightly 1.8.0.22032023.post333__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +4 -3
  2. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +294 -284
  3. model_compression_toolkit/__init__.py +9 -32
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/framework_info.py +1 -1
  13. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  14. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  15. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  18. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  19. model_compression_toolkit/core/common/memory_computation.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  28. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  29. model_compression_toolkit/core/common/model_collector.py +2 -2
  30. model_compression_toolkit/core/common/model_validation.py +1 -1
  31. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  32. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  33. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  35. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  36. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  37. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  50. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  52. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  54. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  55. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  56. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  57. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  58. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  60. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  63. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  65. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  66. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  67. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  69. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  73. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  74. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  75. model_compression_toolkit/core/keras/constants.py +0 -7
  76. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  85. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  86. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  87. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  88. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  89. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  90. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  91. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  92. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  93. model_compression_toolkit/core/keras/reader/common.py +1 -1
  94. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  95. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  99. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  100. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/constants.py +0 -6
  102. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  103. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  109. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  110. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  111. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  112. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  113. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  114. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  115. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  116. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  117. model_compression_toolkit/core/runner.py +7 -7
  118. model_compression_toolkit/exporter/__init__.py +6 -3
  119. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  120. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  121. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  124. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  125. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  126. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  127. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  128. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  129. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  130. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
  131. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  132. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  133. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  135. model_compression_toolkit/gptq/common/gptq_config.py +2 -4
  136. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  137. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  138. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  139. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  140. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  141. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  142. model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
  143. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  144. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  145. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  146. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
  147. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +21 -16
  148. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  149. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
  150. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  151. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  152. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
  153. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  154. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  155. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  156. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
  157. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +13 -5
  158. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  159. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
  160. model_compression_toolkit/gptq/runner.py +3 -2
  161. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
  162. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  163. model_compression_toolkit/ptq/__init__.py +3 -0
  164. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  165. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  166. model_compression_toolkit/qat/__init__.py +4 -0
  167. model_compression_toolkit/qat/common/__init__.py +1 -2
  168. model_compression_toolkit/qat/common/qat_config.py +5 -1
  169. model_compression_toolkit/qat/keras/quantization_facade.py +34 -28
  170. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  171. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  172. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
  173. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
  174. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  175. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  176. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
  177. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
  178. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  179. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +3 -5
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -3
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  211. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  212. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  213. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
  214. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  215. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  217. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  218. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  219. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  220. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  221. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  222. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  223. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  224. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  225. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  226. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  227. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  228. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  229. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  233. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  234. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  235. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  236. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  237. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  238. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  239. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  240. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  241. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  242. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  243. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  244. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  248. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  250. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  254. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  255. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  257. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  259. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  261. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  263. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  265. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  273. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  274. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  275. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  276. model_compression_toolkit/exporter/model_exporter/tflite/__init__.py +0 -14
  277. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  278. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  279. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  280. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  281. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  282. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  283. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  284. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  285. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  286. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  287. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  288. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  289. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  290. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  291. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  292. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  293. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  294. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import tensorflow as tf
17
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD
17
+ from model_compression_toolkit.constants import MIN_THRESHOLD
18
18
  from typing import Tuple
19
19
 
20
20
 
@@ -26,6 +26,14 @@ def ste_ceil(x: tf.Tensor) -> tf.Tensor:
26
26
  return error + x
27
27
 
28
28
 
29
+ def ste_floor(x: tf.Tensor) -> tf.Tensor:
30
+ """
31
+ Return the floor values of a tensor.
32
+ """
33
+ error = tf.stop_gradient(tf.math.floor(x) - x)
34
+ return error + x
35
+
36
+
29
37
  def safe_log(x: tf.Tensor, eps: float) -> tf.Tensor:
30
38
  """
31
39
  Computes log function of x unless x is smaller than some small value, so the log function would not fail.
@@ -72,6 +80,15 @@ def calculate_delta(max_tensor: tf.Tensor,
72
80
  return max_tensor / (2 ** (num_bits - int(signed)))
73
81
 
74
82
 
83
+ def calculate_delta_uniform(min_tensor: tf.Tensor,
84
+ max_tensor: tf.Tensor,
85
+ num_bits: int) -> tf.Tensor:
86
+ """
87
+ Compute the step size for the uniform quantization.
88
+ """
89
+ return (max_tensor-min_tensor) / (2 ** num_bits - 1)
90
+
91
+
75
92
  def ste_clip(x: [tf.Tensor, tf.Variable], max_val=1, min_val=None) -> tf.Tensor:
76
93
  """
77
94
  clip a variable between fixed values such that min_val<=output<=max_val
@@ -77,7 +77,7 @@ class SoftQuantizerRegularization:
77
77
  # Initializing the temperature decay according to the number of expected gradient steps
78
78
  self.linear_decay = LinearTempDecay(total_gradient_steps)
79
79
 
80
- self.count_iter = 0
80
+ self.count_iter = tf.Variable(0.)
81
81
 
82
82
 
83
83
  def __call__(self, model: Model, entropy_reg: float):
@@ -90,16 +90,14 @@ class SoftQuantizerRegularization:
90
90
 
91
91
  Returns: Regularization value.
92
92
  """
93
-
94
93
  soft_reg_aux: List[tf.Tensor] = []
94
+ b = self.linear_decay(self.count_iter.value())
95
95
  for layer in model.layers:
96
96
  if isinstance(layer, KerasQuantizationWrapper):
97
97
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
98
98
  fw_info=DEFAULT_KERAS_INFO)
99
99
 
100
100
  st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
101
- b = self.linear_decay(self.count_iter)
102
-
103
101
  soft_reg_aux.append(tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b)))
104
102
 
105
103
  reg = 0
@@ -107,6 +105,6 @@ class SoftQuantizerRegularization:
107
105
  for sq in soft_reg_aux:
108
106
  reg += sq
109
107
 
110
- self.count_iter += 1
108
+ self.count_iter.assign_add(1.0)
111
109
 
112
110
  return entropy_reg * reg
@@ -19,12 +19,12 @@ import numpy as np
19
19
  from model_compression_toolkit.gptq import RoundingType
20
20
  from model_compression_toolkit import quantizers_infrastructure as qi
21
21
  from model_compression_toolkit.core.common import max_power_of_two
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
23
  from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
24
24
  SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
25
25
  from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
26
26
  from typing import Dict, Any
27
- from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
27
+ from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
28
28
  from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
29
29
  from model_compression_toolkit.gptq.keras.quantizer.quant_utils import power_of_two_max, clip, calculate_delta
30
30
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
@@ -188,6 +188,13 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
188
188
 
189
189
  ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
190
190
 
191
+ #####################################################
192
+ # Soft Rounding
193
+ #####################################################
194
+ aux_var = self.get_soft_targets()
195
+ if not training:
196
+ aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
197
+
191
198
  if self.per_channel:
192
199
  reshape_shape = get_threshold_reshape_shape(inputs.shape,
193
200
  quant_axis=self.quantization_axis,
@@ -197,13 +204,6 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
197
204
  # Calculate soft rounding targets and optimized threshold
198
205
  ##########################################################
199
206
  ptq_threshold_tensor_hat = tf.reshape(ptq_threshold_tensor, reshape_shape)
200
- aux_var = self.get_soft_targets()
201
-
202
- #####################################################
203
- # Soft Rounding
204
- #####################################################
205
- if not training:
206
- aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
207
207
 
208
208
  #####################################################
209
209
  # Quantized Input
@@ -219,14 +219,19 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
219
219
  scale = tf.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
220
220
  q_tensor *= scale
221
221
 
222
- return q_tensor
223
222
  else:
224
- return soft_rounding_symmetric_quantizer(input_tensor=inputs,
225
- auxvar_tensor=self.quantizer_parameters[AUXVAR]['var'],
226
- threshold_tensor=ptq_threshold_tensor.value(),
227
- num_bits=self.num_bits,
228
- signed=True,
229
- power_of_two=self.power_of_two)
223
+ q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
224
+ auxvar_tensor=aux_var,
225
+ threshold_tensor=ptq_threshold_tensor.value(),
226
+ num_bits=self.num_bits,
227
+ signed=True,
228
+ power_of_two=self.power_of_two)
229
+
230
+ if self.quantization_parameter_learning and not self.power_of_two:
231
+ scale = self.get_quantizer_variable(SCALE_PTQ)
232
+ q_tensor *= scale
233
+
234
+ return q_tensor
230
235
 
231
236
  def get_quant_config(self) -> Dict[str, np.ndarray]:
232
237
  """
@@ -0,0 +1,224 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import tensorflow as tf
17
+ import numpy as np
18
+
19
+ from model_compression_toolkit.gptq import RoundingType
20
+ from model_compression_toolkit import quantizers_infrastructure as qi
21
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
+ from model_compression_toolkit.gptq.common.gptq_constants import \
24
+ SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
25
+ from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
26
+ from typing import Dict, Any
27
+ from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
28
+ from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
29
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
30
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
31
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
32
+ get_threshold_reshape_shape
33
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
34
+
35
+
36
+ def soft_rounding_uniform_quantizer(input_tensor: tf.Tensor,
37
+ auxvar_tensor: tf.Variable,
38
+ min_tensor: tf.Tensor,
39
+ max_tensor: tf.Tensor,
40
+ num_bits: int) -> tf.Tensor:
41
+ """
42
+ Quantize a tensor uniformly for GPTQ quantizers.
43
+
44
+ Args:
45
+ input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
46
+ auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to gptq training.
47
+ min_tensor: Tensor with values to compute the min threshold.
48
+ max_tensor: Tensor with values to compute the max threshold.
49
+ num_bits: Num of bits to use.
50
+
51
+ Returns:
52
+ A quantized tensor.
53
+ """
54
+ # adjusts the quantization range so the quantization grid includes zero.
55
+ min_range, max_range = qutils.fix_range_to_include_zero(min_tensor, max_tensor, num_bits)
56
+ delta = qutils.calculate_delta_uniform(min_range, max_range, num_bits)
57
+ input_tensor_int = qutils.ste_floor((input_tensor - min_range) / delta)
58
+ tensor_q = input_tensor_int + auxvar_tensor
59
+ return delta * qutils.ste_clip(tensor_q,
60
+ min_val=0,
61
+ max_val=2 ** num_bits - 1) + min_range
62
+
63
+
64
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
65
+ quantization_method=[QuantizationMethod.UNIFORM],
66
+ quantizer_type=RoundingType.SoftQuantizer)
67
+ class UniformSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
68
+ """
69
+ Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
70
+ """
71
+
72
+ def __init__(self,
73
+ quantization_config: TrainableQuantizerWeightsConfig,
74
+ quantization_parameter_learning: bool = False):
75
+ """
76
+ Initialize a UniformSoftRoundingGPTQ object with parameters to use
77
+ for the quantization.
78
+
79
+ Args:
80
+ quantization_config: Trainable weight quantizer config.
81
+ quantization_parameter_learning: Whether to train the quantization threshold.
82
+ """
83
+ super().__init__(quantization_config)
84
+ self.num_bits = quantization_config.weights_n_bits
85
+ self.per_channel = quantization_config.weights_per_channel_threshold
86
+
87
+ self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
88
+ self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
89
+
90
+ self.quantization_axis = quantization_config.weights_channels_axis
91
+ assert quantization_parameter_learning is False, \
92
+ "Quantization parameters learning in UniformSoftRoundingGPTQ not implemented yet"
93
+ self.quantization_parameter_learning = quantization_parameter_learning
94
+ self.num_channels = self.min_values.shape[self.quantization_axis] if self.per_channel else 1
95
+
96
+ # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
97
+ # See: https://arxiv.org/pdf/2004.10568.pdf
98
+ self.gamma = SOFT_ROUNDING_GAMMA
99
+ self.zeta = SOFT_ROUNDING_ZETA
100
+
101
+ def initialize_quantization(self,
102
+ tensor_shape: Any,
103
+ name: str,
104
+ layer: Any):
105
+ """
106
+ Add quantizer parameters to the quantizer parameters dictionary
107
+
108
+ Args:
109
+ tensor_shape: tensor shape of the quantized tensor.
110
+ name: Tensor name.
111
+ layer: Layer to quantize.
112
+ """
113
+
114
+ if self.per_channel:
115
+ reshape_shape = get_threshold_reshape_shape(tensor_shape,
116
+ quant_axis=self.quantization_axis,
117
+ quant_axis_dim=self.num_channels)
118
+ else:
119
+ reshape_shape = [self.num_channels]
120
+
121
+ min_tensor = layer.add_weight(
122
+ f"{name}_{FQ_MIN}",
123
+ shape=reshape_shape,
124
+ initializer=tf.keras.initializers.Constant(1.0),
125
+ trainable=False)
126
+ min_tensor.assign(self.min_values.reshape(reshape_shape))
127
+
128
+ max_tensor = layer.add_weight(
129
+ f"{name}_{FQ_MAX}",
130
+ shape=reshape_shape,
131
+ initializer=tf.keras.initializers.Constant(1.0),
132
+ trainable=False)
133
+ max_tensor.assign(self.max_values.reshape(reshape_shape))
134
+
135
+ w = getattr(layer.layer, name)
136
+ auxvar_tensor = layer.add_weight(
137
+ f"{name}_{AUXVAR}",
138
+ shape=list(w.shape),
139
+ initializer=tf.keras.initializers.Constant(0.0),
140
+ trainable=True)
141
+
142
+ w = layer.layer.depthwise_kernel if isinstance(layer.layer, (tf.keras.layers.DepthwiseConv2D,
143
+ tf.keras.layers.DepthwiseConv1D)) \
144
+ else layer.layer.kernel
145
+ delta = qutils.calculate_delta_uniform(min_tensor, max_tensor, self.num_bits)
146
+ w_clipped_normed = qutils.clip((w - min_tensor)/ delta, 0, 2 ** self.num_bits - 1)
147
+ rest = w_clipped_normed - tf.floor(w_clipped_normed) # rest of rounding [0, 1)
148
+ alpha = -qutils.safe_log((self.zeta - self.gamma) / (rest - self.gamma) - 1, 1e-16) # => sigmoid(alpha) = rest
149
+ auxvar_tensor.assign(alpha)
150
+
151
+ # Add quantization variables
152
+ self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
153
+ self.add_quantizer_variable(RANGE_MIN, min_tensor, VariableGroup.QPARAMS)
154
+ self.add_quantizer_variable(RANGE_MAX, max_tensor, VariableGroup.QPARAMS)
155
+
156
+ def get_soft_targets(self) -> tf.Tensor:
157
+ """
158
+ Computes the rectified sigmoid function for the quantization target parameters.
159
+
160
+ Returns:
161
+ A tensor with the soft rounding targets values.
162
+
163
+ """
164
+ return qutils.clip(
165
+ tf.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma, 1, 0)
166
+
167
+ def __call__(self,
168
+ inputs: tf.Tensor,
169
+ training: bool):
170
+ """
171
+ Quantize a tensor.
172
+
173
+ Args:
174
+ inputs: Input tensor to quantize.
175
+ training: Whether the graph is in training mode.
176
+
177
+ Returns:
178
+ The quantized tensor.
179
+ """
180
+
181
+ min_tensor = self.get_quantizer_variable(RANGE_MIN)
182
+ max_tensor = self.get_quantizer_variable(RANGE_MAX)
183
+
184
+ #####################################################
185
+ # Soft Rounding
186
+ #####################################################
187
+ aux_var = self.get_soft_targets()
188
+ if not training:
189
+ aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
190
+
191
+ if self.per_channel:
192
+ reshape_shape = get_threshold_reshape_shape(inputs.shape,
193
+ quant_axis=self.quantization_axis,
194
+ quant_axis_dim=-1)
195
+
196
+ #####################################################
197
+ # Quantized Input
198
+ #####################################################
199
+ q_tensor = soft_rounding_uniform_quantizer(input_tensor=inputs,
200
+ auxvar_tensor=aux_var,
201
+ min_tensor=tf.reshape(min_tensor, reshape_shape),
202
+ max_tensor=tf.reshape(max_tensor, reshape_shape),
203
+ num_bits=self.num_bits)
204
+
205
+ else:
206
+ q_tensor = soft_rounding_uniform_quantizer(input_tensor=inputs,
207
+ auxvar_tensor=aux_var,
208
+ min_tensor=min_tensor,
209
+ max_tensor=max_tensor,
210
+ num_bits=self.num_bits)
211
+
212
+ return q_tensor
213
+
214
+ def get_quant_config(self) -> Dict[str, np.ndarray]:
215
+ """
216
+ Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
217
+
218
+ Returns:
219
+ A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
220
+ Keys must match NodeQuantizationConfig attributes
221
+ """
222
+
223
+ return {RANGE_MIN: self.get_quantizer_variable(RANGE_MIN).numpy(),
224
+ RANGE_MAX: self.get_quantizer_variable(RANGE_MAX).numpy()}
@@ -20,10 +20,10 @@ import tensorflow as tf
20
20
 
21
21
  from model_compression_toolkit.gptq import RoundingType
22
22
  from model_compression_toolkit import quantizers_infrastructure as qi
23
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
23
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
24
24
  from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
25
25
  from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
26
- from model_compression_toolkit.core.common.constants import THRESHOLD
26
+ from model_compression_toolkit.constants import THRESHOLD
27
27
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
28
28
  from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
29
29
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
@@ -0,0 +1,29 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Type
17
+
18
+ from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
19
+ from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
20
+ from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
21
+
22
+
23
+ class GPTQPytorchImplemantation(GPTQFrameworkImplemantation, PytorchImplementation):
24
+
25
+ def get_gptq_trainer_obj(self) -> Type[PytorchGPTQTrainer]:
26
+ """
27
+ Returns: Pytorch object of GPTQTrainer
28
+ """
29
+ return PytorchGPTQTrainer
@@ -19,7 +19,7 @@ from torch.nn import Module
19
19
  from tqdm import tqdm
20
20
  import copy
21
21
  import torch
22
- from model_compression_toolkit.core.common.logger import Logger
22
+ from model_compression_toolkit.logger import Logger
23
23
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
24
24
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
25
25
  from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
@@ -14,18 +14,18 @@
14
14
  # ==============================================================================
15
15
  from typing import Callable
16
16
  from model_compression_toolkit.core import common
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import PYTORCH
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import PYTORCH
20
20
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
21
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
22
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
23
23
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
24
24
  from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
25
25
  from model_compression_toolkit.gptq.runner import gptq_runner
26
26
  from model_compression_toolkit.core.exporter import export_model
27
27
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
28
- from model_compression_toolkit import CoreConfig
28
+ from model_compression_toolkit.core import CoreConfig
29
29
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
30
30
  MixedPrecisionQuantizationConfigV2
31
31
 
@@ -36,8 +36,8 @@ LR_QUANTIZATION_PARAM_DEFAULT = 1e-4
36
36
 
37
37
  if FOUND_TORCH:
38
38
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
39
- from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
40
- from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
39
+ from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
40
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
41
41
  from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
42
42
  from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
43
43
  import torch
@@ -118,7 +118,7 @@ if FOUND_TORCH:
118
118
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
119
119
  gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
120
120
  gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
121
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. `Default PyTorch TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/pytorch_tp_models/pytorch_default.py>`_
121
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
122
122
  new_experimental_exporter (bool): Whether exporting the quantized model using new exporter or not (in progress. Avoiding it for now is recommended).
123
123
 
124
124
  Returns:
@@ -142,7 +142,7 @@ if FOUND_TORCH:
142
142
 
143
143
  Create MCT core configurations with number of calibration iterations set to 1:
144
144
 
145
- >>> config = mct.CoreConfig()
145
+ >>> config = mct.core.CoreConfig()
146
146
 
147
147
  Pass the module, the representative dataset generator and the configuration (optional) to get a quantized module
148
148
 
@@ -152,16 +152,16 @@ if FOUND_TORCH:
152
152
 
153
153
  if core_config.mixed_precision_enable:
154
154
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
155
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
155
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
156
156
  "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
157
157
  "API, or pass a valid mixed precision configuration.") # pragma: no cover
158
158
 
159
- common.Logger.info("Using experimental mixed-precision quantization. "
159
+ Logger.info("Using experimental mixed-precision quantization. "
160
160
  "If you encounter an issue please file a bug.")
161
161
 
162
162
  tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
163
163
 
164
- fw_impl = PytorchImplementation()
164
+ fw_impl = GPTQPytorchImplemantation()
165
165
 
166
166
  # ---------------------- #
167
167
  # Core Runner
@@ -192,7 +192,7 @@ if FOUND_TORCH:
192
192
  Logger.warning('Using new experimental exported models. '
193
193
  'Please do not use unless you are familiar with what you are doing')
194
194
 
195
- return get_fully_quantized_pytorch_model(graph_gptq)
195
+ return get_exportable_pytorch_model(graph_gptq)
196
196
 
197
197
  return export_model(graph_gptq,
198
198
  DEFAULT_PYTORCH_INFO,
@@ -15,3 +15,4 @@
15
15
 
16
16
  import model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.symmetric_ste
17
17
  import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.symmetric_soft_quantizer
18
+ import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.uniform_soft_quantizer
@@ -13,10 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from abc import abstractmethod
16
- from typing import Union, Dict, List
16
+ from typing import Union, Dict
17
17
 
18
- from model_compression_toolkit.core.common.logger import Logger
19
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import FOUND_TORCH
20
20
  from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
21
21
 
22
22
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
@@ -14,9 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import Union, Tuple
16
16
  import torch
17
- from torch.nn.functional import softmax, log_softmax, one_hot
18
- from model_compression_toolkit.core.common.constants import MIN_THRESHOLD
19
- from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
17
+ from model_compression_toolkit.constants import MIN_THRESHOLD
20
18
 
21
19
 
22
20
  def power_of_two_max(max_tensor: torch.Tensor) -> torch.Tensor:
@@ -30,11 +28,20 @@ def calculate_delta(max_tensor: torch.Tensor,
30
28
  num_bits: int,
31
29
  signed: bool) -> torch.Tensor:
32
30
  """
33
- Compute the step size for the quantization.
31
+ Compute the step size for the symmetric quantization.
34
32
  """
35
33
  return max_tensor / (2 ** (num_bits - int(signed)))
36
34
 
37
35
 
36
+ def calculate_delta_uniform(min_tensor: torch.Tensor,
37
+ max_tensor: torch.Tensor,
38
+ num_bits: int) -> torch.Tensor:
39
+ """
40
+ Compute the step size for the uniform quantization.
41
+ """
42
+ return (max_tensor-min_tensor) / (2 ** num_bits - 1)
43
+
44
+
38
45
  def ste_ceil(x: torch.Tensor) -> torch.Tensor:
39
46
  """
40
47
  Return the ceil values of a tensor.
@@ -42,6 +49,13 @@ def ste_ceil(x: torch.Tensor) -> torch.Tensor:
42
49
  return (torch.ceil(x) - x).detach() + x
43
50
 
44
51
 
52
+ def ste_floor(x: torch.Tensor) -> torch.Tensor:
53
+ """
54
+ Return the floor values of a tensor.
55
+ """
56
+ return (torch.floor(x) - x).detach() + x
57
+
58
+
45
59
  def ste_round(x: torch.Tensor) -> torch.Tensor:
46
60
  """
47
61
  Calculate the rounded values of a tensor
@@ -95,14 +95,13 @@ class SoftQuantizerRegularization:
95
95
  """
96
96
 
97
97
  soft_reg_aux: List[torch.Tensor] = []
98
+ b = self.linear_decay(self.count_iter)
98
99
  for layer in model.modules():
99
100
  if isinstance(layer, PytorchQuantizationWrapper):
100
101
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
101
102
  fw_info=DEFAULT_PYTORCH_INFO)
102
103
 
103
104
  st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
104
- b = self.linear_decay(self.count_iter)
105
-
106
105
  soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
107
106
 
108
107
  reg = 0
@@ -19,7 +19,7 @@ import numpy as np
19
19
 
20
20
  from model_compression_toolkit.core.common import max_power_of_two
21
21
  from model_compression_toolkit import quantizers_infrastructure as qi
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
23
  from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
24
  from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
25
25
  BasePytorchGPTQTrainableQuantizer
@@ -27,7 +27,7 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_
27
27
  from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
28
28
  from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
29
29
  SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
30
- from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
30
+ from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
31
31
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
32
32
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
33
33
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
@@ -142,9 +142,13 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
142
142
  self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
143
143
 
144
144
  if self.quantization_parameter_learning:
145
- layer.register_parameter(f"{name}_{SCALE_PTQ}",
146
- nn.Parameter(to_torch_tensor(torch.ones_like(torch.Tensor(self.threshold_values))),
147
- requires_grad=True))
145
+ if self.per_channel:
146
+ layer.register_parameter(f"{name}_{SCALE_PTQ}",
147
+ nn.Parameter(to_torch_tensor(torch.ones_like(torch.Tensor(self.threshold_values))),
148
+ requires_grad=True))
149
+ else:
150
+ layer.register_parameter(f"{name}_{SCALE_PTQ}",
151
+ nn.Parameter(to_torch_tensor((torch.tensor([1.0], requires_grad=True)))))
148
152
  self.add_quantizer_variable(SCALE_PTQ, layer.get_parameter(f"{name}_{SCALE_PTQ}"), VariableGroup.QPARAMS)
149
153
 
150
154
  def get_soft_targets(self) -> torch.Tensor:
@@ -233,4 +237,8 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
233
237
  signed=True,
234
238
  power_of_two=self.power_of_two)
235
239
 
240
+ if self.quantization_parameter_learning and not self.power_of_two:
241
+ scale = self.get_quantizer_variable(SCALE_PTQ)
242
+ q_tensor *= scale
243
+
236
244
  return q_tensor