mct-nightly 1.8.0.22032023.post333__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +4 -3
  2. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +294 -284
  3. model_compression_toolkit/__init__.py +9 -32
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/framework_info.py +1 -1
  13. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  14. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  15. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  18. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  19. model_compression_toolkit/core/common/memory_computation.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  28. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  29. model_compression_toolkit/core/common/model_collector.py +2 -2
  30. model_compression_toolkit/core/common/model_validation.py +1 -1
  31. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  32. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  33. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  35. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  36. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  37. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  50. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  52. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  54. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  55. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  56. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  57. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  58. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  60. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  63. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  65. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  66. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  67. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  69. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  73. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  74. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  75. model_compression_toolkit/core/keras/constants.py +0 -7
  76. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  85. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  86. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  87. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  88. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  89. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  90. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  91. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  92. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  93. model_compression_toolkit/core/keras/reader/common.py +1 -1
  94. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  95. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  99. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  100. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/constants.py +0 -6
  102. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  103. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  109. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  110. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  111. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  112. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  113. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  114. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  115. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  116. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  117. model_compression_toolkit/core/runner.py +7 -7
  118. model_compression_toolkit/exporter/__init__.py +6 -3
  119. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  120. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  121. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  124. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  125. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  126. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  127. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  128. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  129. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  130. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
  131. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  132. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  133. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  135. model_compression_toolkit/gptq/common/gptq_config.py +2 -4
  136. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  137. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  138. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  139. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  140. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  141. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  142. model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
  143. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  144. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  145. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  146. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
  147. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +21 -16
  148. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  149. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
  150. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  151. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  152. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
  153. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  154. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  155. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  156. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
  157. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +13 -5
  158. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  159. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
  160. model_compression_toolkit/gptq/runner.py +3 -2
  161. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
  162. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  163. model_compression_toolkit/ptq/__init__.py +3 -0
  164. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  165. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  166. model_compression_toolkit/qat/__init__.py +4 -0
  167. model_compression_toolkit/qat/common/__init__.py +1 -2
  168. model_compression_toolkit/qat/common/qat_config.py +5 -1
  169. model_compression_toolkit/qat/keras/quantization_facade.py +34 -28
  170. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  171. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  172. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
  173. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
  174. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  175. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  176. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
  177. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
  178. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  179. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +3 -5
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -3
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  211. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  212. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  213. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
  214. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  215. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  217. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  218. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  219. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  220. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  221. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  222. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  223. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  224. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  225. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  226. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  227. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  228. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  229. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  233. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  234. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  235. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  236. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  237. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  238. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  239. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  240. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  241. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  242. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  243. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  244. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  248. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  250. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  254. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  255. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  257. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  259. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  261. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  263. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  265. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  273. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  274. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  275. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  276. model_compression_toolkit/exporter/model_exporter/tflite/__init__.py +0 -14
  277. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  278. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  279. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  280. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  281. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  282. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  283. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  284. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  285. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  286. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  287. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  288. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  289. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  290. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  291. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  292. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  293. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  294. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -0,0 +1,194 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Dict
18
+ import numpy as np
19
+
20
+ from model_compression_toolkit import quantizers_infrastructure as qi
21
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
+ from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
25
+ BasePytorchGPTQTrainableQuantizer
26
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
27
+ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
28
+ from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
29
+ from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
30
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
31
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
32
+ mark_quantizer
33
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
34
+ VariableGroup
35
+ from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
36
+
37
+
38
+ def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
39
+ auxvar_tensor: torch.Tensor,
40
+ min_range: torch.Tensor,
41
+ max_range: torch.Tensor,
42
+ num_bits: int) -> torch.Tensor:
43
+ """
44
+ Quantize a tensor uniformly for GPTQ quantizers.
45
+
46
+ Args:
47
+ input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
48
+ auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to gptq training.
49
+ min_range: Tensor with min values to compute the delta grid.
50
+ max_range: Tensor with max values to compute the delta grid.
51
+ num_bits: Num of bits to use.
52
+
53
+ Returns:
54
+ A quantized tensor.
55
+ """
56
+ # adjusts the quantization range so the quantization grid includes zero.
57
+ min_range, max_range = fix_range_to_include_zero(min_range, max_range, num_bits)
58
+ delta = qutils.calculate_delta_uniform(min_range, max_range, num_bits)
59
+ input_tensor_int = qutils.ste_floor((input_tensor - min_range) / delta)
60
+ tensor_q = input_tensor_int + auxvar_tensor
61
+ return delta * qutils.ste_clip(tensor_q,
62
+ min_val=0,
63
+ max_val=2 ** num_bits - 1) + min_range
64
+
65
+
66
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
67
+ quantization_method=[QuantizationMethod.UNIFORM],
68
+ quantizer_type=RoundingType.SoftQuantizer)
69
+ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
70
+ """
71
+ Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
72
+ """
73
+
74
+ def __init__(self,
75
+ quantization_config: TrainableQuantizerWeightsConfig,
76
+ quantization_parameter_learning: bool = False):
77
+ """
78
+ Construct a Pytorch model that utilize a fake weight quantizer of soft-quantizer for symmetric quantizer.
79
+
80
+ Args:
81
+ quantization_config: Trainable weights quantizer config.
82
+ quantization_parameter_learning (Bool): Whether to learn the min/max ranges or not
83
+ """
84
+
85
+ super().__init__(quantization_config)
86
+ self.num_bits = quantization_config.weights_n_bits
87
+ self.per_channel = quantization_config.weights_per_channel_threshold
88
+
89
+ self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
90
+ self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
91
+
92
+ self.quantization_axis = quantization_config.weights_channels_axis
93
+ self.quantization_parameter_learning = quantization_parameter_learning
94
+
95
+ # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
96
+ # See: https://arxiv.org/pdf/2004.10568.pdf
97
+ self.gamma = SOFT_ROUNDING_GAMMA
98
+ self.zeta = SOFT_ROUNDING_ZETA
99
+
100
+ def initialize_quantization(self,
101
+ tensor_shape: torch.Size,
102
+ name: str,
103
+ layer: qi.PytorchQuantizationWrapper):
104
+ """
105
+ Add quantizer parameters to the quantizer parameters dictionary
106
+
107
+ Args:
108
+ tensor_shape: tensor shape of the quantized tensor.
109
+ name: Tensor name.
110
+ layer: Layer to quantize.
111
+ """
112
+
113
+ # Add min and max variables to layer.
114
+ if self.per_channel:
115
+ min_values = to_torch_tensor(self.min_values)
116
+ max_values = to_torch_tensor(self.max_values)
117
+ else:
118
+ min_values = torch.tensor(self.min_values)
119
+ max_values = torch.tensor(self.max_values)
120
+
121
+ layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(min_values, requires_grad=self.quantization_parameter_learning))
122
+ layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(max_values, requires_grad=self.quantization_parameter_learning))
123
+
124
+ w = layer.layer.weight
125
+ delta = qutils.calculate_delta_uniform(min_values, max_values, self.num_bits)
126
+ w_clipped_normed = torch.clip((w - min_values) / delta, 0, 2 ** self.num_bits - 1)
127
+ rest = w_clipped_normed - torch.floor(w_clipped_normed) # rest of rounding [0, 1)
128
+ alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest
129
+ layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
130
+
131
+ # Save the quantizer parameters
132
+ self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
133
+ self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
134
+ self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
135
+
136
+ def get_soft_targets(self) -> torch.Tensor:
137
+ """
138
+ Computes the rectified sigmoid function for the quantization target parameters.
139
+
140
+ Returns:
141
+ A tensor with the soft rounding targets values.
142
+
143
+ """
144
+ scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
145
+ return torch.clip(scaled_sigmoid, min=0, max=1)
146
+
147
+ def get_quant_config(self) -> Dict[str, np.ndarray]:
148
+ """
149
+ Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
150
+
151
+ Returns:
152
+ A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
153
+ Keys must match NodeQuantizationConfig attributes
154
+
155
+ """
156
+ min_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MIN))
157
+ max_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MAX))
158
+ return {RANGE_MIN: min_values,
159
+ RANGE_MAX: max_values}
160
+
161
+ def __call__(self,
162
+ inputs: nn.Parameter,
163
+ training: bool) -> torch.Tensor:
164
+ """
165
+ Quantize a tensor.
166
+
167
+ Args:
168
+ inputs: Input tensor to quantize.
169
+ training: whether in training mode or not
170
+
171
+ Returns:
172
+ quantized tensor
173
+ """
174
+ auxvar = self.get_quantizer_variable(AUXVAR)
175
+ min_range = self.get_quantizer_variable(FQ_MIN)
176
+ max_range = self.get_quantizer_variable(FQ_MAX)
177
+
178
+ #####################################################
179
+ # Soft Rounding
180
+ #####################################################
181
+ aux_var = self.get_soft_targets()
182
+ if not training:
183
+ aux_var = (aux_var >= 0.5).to(auxvar.dtype)
184
+
185
+ #####################################################
186
+ # Quantized Input
187
+ #####################################################
188
+ q_tensor = soft_rounding_unifrom_quantizer(input_tensor=inputs,
189
+ auxvar_tensor=aux_var,
190
+ min_range=min_range,
191
+ max_range=max_range,
192
+ num_bits=self.num_bits)
193
+
194
+ return q_tensor
@@ -19,14 +19,14 @@ import numpy as np
19
19
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
20
20
 
21
21
  from model_compression_toolkit import quantizers_infrastructure as qi
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
23
  from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
24
  from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
25
25
  BasePytorchGPTQTrainableQuantizer
26
26
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
27
27
  from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
28
28
  from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
29
- from model_compression_toolkit.core.common.constants import THRESHOLD
29
+ from model_compression_toolkit.constants import THRESHOLD
30
30
  from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
31
31
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
32
32
  mark_quantizer
@@ -15,7 +15,7 @@
15
15
 
16
16
  from typing import Callable
17
17
 
18
- from model_compression_toolkit import CoreConfig
18
+ from model_compression_toolkit.core import CoreConfig
19
19
  from model_compression_toolkit.core import common
20
20
  from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
21
21
  apply_statistics_correction
@@ -28,6 +28,7 @@ from model_compression_toolkit.gptq.common.gptq_training import gptq_training
28
28
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
29
29
  from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \
30
30
  apply_bias_correction_to_graph
31
+ from model_compression_toolkit.logger import Logger
31
32
 
32
33
 
33
34
  def _apply_gptq(gptq_config: GradientPTQConfigV2,
@@ -55,7 +56,7 @@ def _apply_gptq(gptq_config: GradientPTQConfigV2,
55
56
 
56
57
  """
57
58
  if gptq_config is not None and gptq_config.n_epochs > 0:
58
- common.Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
59
+ Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
59
60
  "please file a bug. To disable it, do not pass a gptq configuration.")
60
61
 
61
62
  tg_bias = gptq_training(tg,
@@ -15,9 +15,8 @@
15
15
 
16
16
  from typing import Callable, List, Tuple
17
17
 
18
- from model_compression_toolkit.core import common
19
- from model_compression_toolkit.core.common import Logger
20
- from model_compression_toolkit.core.common.constants import TENSORFLOW
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import TENSORFLOW
21
20
  from model_compression_toolkit.core.common.user_info import UserInformation
22
21
  from model_compression_toolkit.gptq import GradientPTQConfig, GradientPTQConfigV2
23
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
@@ -35,15 +34,15 @@ from model_compression_toolkit.ptq.runner import ptq_runner
35
34
  from model_compression_toolkit.core.exporter import export_model
36
35
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
37
36
 
38
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
39
- from model_compression_toolkit.core.common.constants import FOUND_TF
37
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
38
+ from model_compression_toolkit.constants import FOUND_TF
40
39
 
41
40
  if FOUND_TF:
42
41
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
43
42
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
44
43
  from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
45
44
  from tensorflow.keras.models import Model
46
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
45
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
47
46
 
48
47
  from model_compression_toolkit import get_target_platform_capabilities
49
48
 
@@ -81,7 +80,7 @@ if FOUND_TF:
81
80
  network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
82
81
  gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
83
82
  analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
84
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. `Default Keras TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/keras_tp_models/keras_default.py>`_
83
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
85
84
 
86
85
  Returns:
87
86
  A quantized model and information the user may need to handle the quantized model.
@@ -184,7 +183,7 @@ if FOUND_TF:
184
183
  network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
185
184
  gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
186
185
  analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
187
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. `Default Keras TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/keras_tp_models/keras_default.py>`_
186
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
188
187
 
189
188
 
190
189
  Returns:
@@ -209,13 +208,13 @@ if FOUND_TF:
209
208
  Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
210
209
  The candidates bitwidth for quantization should be defined in the target platform model:
211
210
 
212
- >>> config = mct.MixedPrecisionQuantizationConfig()
211
+ >>> config = mct.core.MixedPrecisionQuantizationConfig()
213
212
 
214
213
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
215
214
  that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
216
215
  while the bias will not):
217
216
 
218
- >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
217
+ >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
219
218
 
220
219
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
221
220
  quantized model:
@@ -229,11 +228,11 @@ if FOUND_TF:
229
228
  fw_info=fw_info).validate()
230
229
 
231
230
  if not isinstance(quant_config, MixedPrecisionQuantizationConfig):
232
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
231
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
233
232
  "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
234
233
  "or pass a valid mixed precision configuration.")
235
234
 
236
- common.Logger.info("Using experimental mixed-precision quantization. "
235
+ Logger.info("Using experimental mixed-precision quantization. "
237
236
  "If you encounter an issue please file a bug.")
238
237
 
239
238
  quantization_config, mp_config = quant_config.separate_configs()
@@ -14,12 +14,11 @@
14
14
  # ==============================================================================
15
15
  from typing import Callable, List, Tuple
16
16
 
17
- from model_compression_toolkit.core import common
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import PYTORCH
17
+ from model_compression_toolkit.logger import Logger
18
+ from model_compression_toolkit.constants import PYTORCH
20
19
  from model_compression_toolkit.core.common.user_info import UserInformation
21
20
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GradientPTQConfigV2
22
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
23
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
24
23
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
24
  from model_compression_toolkit.core.common.network_editors.actions import EditRule
@@ -34,12 +33,12 @@ from model_compression_toolkit.gptq.runner import gptq_runner
34
33
  from model_compression_toolkit.ptq.runner import ptq_runner
35
34
  from model_compression_toolkit.core.exporter import export_model
36
35
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
37
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
36
+ from model_compression_toolkit.constants import FOUND_TORCH
38
37
 
39
38
  if FOUND_TORCH:
40
39
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
41
40
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
42
- from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
41
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
43
42
  from torch.nn import Module
44
43
 
45
44
  from model_compression_toolkit import get_target_platform_capabilities
@@ -76,7 +75,7 @@ if FOUND_TORCH:
76
75
  network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
77
76
  gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
78
77
  analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
79
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. `Default PyTorch TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/pytorch_tp_models/pytorch_default.py>`_
78
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
80
79
 
81
80
 
82
81
  Returns:
@@ -175,7 +174,7 @@ if FOUND_TORCH:
175
174
  network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
176
175
  gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
177
176
  analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
178
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. `Default PyTorch TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/pytorch_tp_models/pytorch_default.py>`_
177
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
179
178
 
180
179
  Returns:
181
180
  A quantized model and information the user may need to handle the quantized model.
@@ -199,13 +198,13 @@ if FOUND_TORCH:
199
198
  Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
200
199
  The candidates bitwidth for quantization should be defined in the target platform model:
201
200
 
202
- >>> config = mct.MixedPrecisionQuantizationConfig()
201
+ >>> config = mct.core.MixedPrecisionQuantizationConfig()
203
202
 
204
203
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
205
204
  that should be quantized (for example, the kernel of Conv2D in PyTorch will be affected by this value,
206
205
  while the bias will not):
207
206
 
208
- >>> kpi = mct.KPI(sum(p.numel() for p in module.parameters()) * 0.75) # About 0.75 of the model size when quantized with 8 bits.
207
+ >>> kpi = mct.core.KPI(sum(p.numel() for p in module.parameters()) * 0.75) # About 0.75 of the model size when quantized with 8 bits.
209
208
 
210
209
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
211
210
  quantized model:
@@ -217,11 +216,11 @@ if FOUND_TORCH:
217
216
  """
218
217
 
219
218
  if not isinstance(quant_config, MixedPrecisionQuantizationConfig):
220
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
219
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
221
220
  "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API, "
222
221
  "or pass a valid mixed precision configuration.")
223
222
 
224
- common.Logger.info("Using experimental mixed-precision quantization. "
223
+ Logger.info("Using experimental mixed-precision quantization. "
225
224
  "If you encounter an issue please file a bug.")
226
225
 
227
226
  quantization_config, mp_config = quant_config.separate_configs()
@@ -12,3 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
16
+ from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
17
+ from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization_experimental
@@ -15,15 +15,14 @@
15
15
 
16
16
  from typing import Callable
17
17
 
18
- from model_compression_toolkit import CoreConfig
19
- from model_compression_toolkit.core import common
18
+ from model_compression_toolkit.core import CoreConfig
20
19
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
21
- from model_compression_toolkit.core.common import Logger
22
- from model_compression_toolkit.core.common.constants import TENSORFLOW, FOUND_TF
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
23
22
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
24
23
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
25
24
  MixedPrecisionQuantizationConfigV2
26
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
25
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
27
26
  from model_compression_toolkit.core.exporter import export_model
28
27
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
29
28
  from model_compression_toolkit.ptq.runner import ptq_runner
@@ -33,7 +32,7 @@ if FOUND_TF:
33
32
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
34
33
  from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
35
34
  from tensorflow.keras.models import Model
36
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
35
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
37
36
  from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
38
37
 
39
38
  from model_compression_toolkit import get_target_platform_capabilities
@@ -93,25 +92,25 @@ if FOUND_TF:
93
92
 
94
93
  Create a MCT core config, containing the quantization configuration:
95
94
 
96
- >>> config = mct.CoreConfig()
95
+ >>> config = mct.core.CoreConfig()
97
96
 
98
97
  If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
99
98
  The candidates bitwidth for quantization should be defined in the target platform model.
100
99
  In this example we use 1 image to search mixed-precision configuration:
101
100
 
102
- >>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2(num_of_images=1))
101
+ >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
103
102
 
104
103
  For mixed-precision set a target KPI object:
105
104
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
106
105
  that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
107
106
  while the bias will not):
108
107
 
109
- >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
108
+ >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
110
109
 
111
110
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
112
111
  quantized model:
113
112
 
114
- >>> quantized_model, quantization_info = mct.keras_post_training_quantization_experimental(model, repr_datagen, kpi, core_config=config)
113
+ >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization_experimental(model, repr_datagen, kpi, core_config=config)
115
114
 
116
115
  For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
117
116
 
@@ -124,11 +123,11 @@ if FOUND_TF:
124
123
 
125
124
  if core_config.mixed_precision_enable:
126
125
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
127
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
126
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
128
127
  "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
129
128
  "API, or pass a valid mixed precision configuration.") # pragma: no cover
130
129
 
131
- common.Logger.info("Using experimental mixed-precision quantization. "
130
+ Logger.info("Using experimental mixed-precision quantization. "
132
131
  "If you encounter an issue please file a bug.")
133
132
 
134
133
  tb_w = _init_tensorboard_writer(fw_info)
@@ -15,11 +15,11 @@
15
15
  from typing import Callable
16
16
 
17
17
  from model_compression_toolkit.core import common
18
- from model_compression_toolkit.core.common import Logger
19
- from model_compression_toolkit.core.common.constants import PYTORCH, FOUND_TORCH
20
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
21
21
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
22
- from model_compression_toolkit import CoreConfig
22
+ from model_compression_toolkit.core import CoreConfig
23
23
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
24
24
  MixedPrecisionQuantizationConfigV2
25
25
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
@@ -31,7 +31,7 @@ from model_compression_toolkit.core.analyzer import analyzer_model_quantization
31
31
  if FOUND_TORCH:
32
32
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
33
33
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
34
- from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
34
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
35
35
  from torch.nn import Module
36
36
  from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
37
37
  from model_compression_toolkit import get_target_platform_capabilities
@@ -88,18 +88,18 @@ if FOUND_TORCH:
88
88
  Set number of clibration iterations to 1:
89
89
 
90
90
  >>> import model_compression_toolkit as mct
91
- >>> quantized_module, quantization_info = mct.pytorch_post_training_quantization_experimental(module, repr_datagen)
91
+ >>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization_experimental(module, repr_datagen)
92
92
 
93
93
  """
94
94
 
95
95
  if core_config.mixed_precision_enable:
96
96
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
97
- common.Logger.error("Given quantization config to mixed-precision facade is not of type "
97
+ Logger.error("Given quantization config to mixed-precision facade is not of type "
98
98
  "MixedPrecisionQuantizationConfigV2. Please use "
99
99
  "pytorch_post_training_quantization API, or pass a valid mixed precision "
100
100
  "configuration.") # pragma: no cover
101
101
 
102
- common.Logger.info("Using experimental mixed-precision quantization. "
102
+ Logger.info("Using experimental mixed-precision quantization. "
103
103
  "If you encounter an issue please file a bug.")
104
104
 
105
105
  tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
@@ -12,3 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
16
+
17
+ from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, keras_quantization_aware_training_finalize
18
+ from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init, pytorch_quantization_aware_training_finalize
@@ -12,5 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
- from model_compression_toolkit.qat.common.constants import THRESHOLD_TENSOR, WEIGHTS_QUANTIZATION_PARAMS
15
+ from model_compression_toolkit.quantizers_infrastructure.constants import THRESHOLD_TENSOR, WEIGHTS_QUANTIZATION_PARAMS
@@ -17,6 +17,8 @@ from typing import Dict
17
17
  from enum import Enum
18
18
  from model_compression_toolkit.core import common
19
19
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
20
+ from model_compression_toolkit.logger import Logger
21
+
20
22
 
21
23
  def _is_qat_applicable(node: common.BaseNode,
22
24
  fw_info: FrameworkInfo) -> bool:
@@ -31,7 +33,7 @@ def _is_qat_applicable(node: common.BaseNode,
31
33
  """
32
34
 
33
35
  if node.is_weights_quantization_enabled() and not fw_info.is_kernel_op(node.type):
34
- common.Logger.error("QAT Error: Quantizing a node without a kernel isn't supported")
36
+ Logger.error("QAT Error: Quantizing a node without a kernel isn't supported")
35
37
  return node.is_weights_quantization_enabled() or node.is_activation_quantization_enabled()
36
38
 
37
39
 
@@ -40,8 +42,10 @@ class TrainingMethod(Enum):
40
42
  An enum for selecting a QAT training method
41
43
 
42
44
  STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers
45
+ DQA - DNN Quantization with Attention. Includes a smooth quantization introduces by DQA method
43
46
  """
44
47
  STE = "STE",
48
+ DQA = "DQA"
45
49
 
46
50
 
47
51
  class QATConfig: