mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from abc import abstractmethod
17
-
18
16
  import tensorflow as tf
19
17
  from keras.engine.input_layer import InputLayer
20
18
  from keras.models import Model, clone_model
@@ -22,7 +20,7 @@ from packaging import version
22
20
 
23
21
  from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
24
22
  from model_compression_toolkit.core.common.user_info import UserInformation
25
- from model_compression_toolkit.core.common.constants import INPUT_BASE_NAME
23
+ from model_compression_toolkit.constants import INPUT_BASE_NAME
26
24
 
27
25
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
28
26
  if version.parse(tf.__version__) < version.parse("2.6"):
@@ -38,7 +36,6 @@ else:
38
36
  from typing import Any, Dict, List, Tuple, Callable
39
37
  from tensorflow.python.util.object_identity import Reference as TFReference
40
38
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
41
- from model_compression_toolkit.core.common.logger import Logger
42
39
  from model_compression_toolkit.core import common
43
40
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
44
41
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -36,7 +36,7 @@ else:
36
36
  from keras.layers.core import TFOpLambda, SlicingOpLambda
37
37
 
38
38
  from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
39
- from model_compression_toolkit.core.common.logger import Logger
39
+ from model_compression_toolkit.logger import Logger
40
40
  from model_compression_toolkit.core import common
41
41
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
42
42
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -26,13 +26,13 @@ else:
26
26
 
27
27
  from typing import Any, Dict, List, Tuple
28
28
  from tensorflow.python.util.object_identity import Reference as TFReference
29
- from model_compression_toolkit.core.common.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
29
+ from model_compression_toolkit.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
30
30
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
31
31
  from model_compression_toolkit.core import common
32
32
  from model_compression_toolkit.core.common import BaseNode, Graph
33
33
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
34
34
  from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
35
- from model_compression_toolkit.core.common.logger import Logger
35
+ from model_compression_toolkit.logger import Logger
36
36
 
37
37
 
38
38
  def build_input_tensors_list(node: BaseNode,
@@ -171,8 +171,9 @@ def keras_iterative_approx_jacobian_trace(graph_float: common.Graph,
171
171
 
172
172
  # If the change to the mean Jacobian approximation is insignificant we stop the calculation
173
173
  if j > MIN_JACOBIANS_ITER:
174
- delta = np.mean([jac_trace_approx, *trace_jv]) - np.mean(trace_jv)
175
- if np.abs(delta) / (np.abs(np.mean(trace_jv)) + 1e-6) < JACOBIANS_COMP_TOLERANCE:
174
+ new_mean = np.mean([jac_trace_approx, *trace_jv])
175
+ delta = new_mean - np.mean(trace_jv)
176
+ if np.abs(delta) / (np.abs(new_mean) + 1e-6) < JACOBIANS_COMP_TOLERANCE:
176
177
  trace_jv.append(jac_trace_approx)
177
178
  break
178
179
 
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import List
16
16
 
17
- from model_compression_toolkit import FrameworkInfo
17
+ from model_compression_toolkit.core import FrameworkInfo
18
18
  from model_compression_toolkit.core import common
19
19
  from model_compression_toolkit.core.common import BaseNode
20
20
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
@@ -99,13 +99,6 @@ OUTPUT_BIAS = '/attention_output/bias'
99
99
  # ReLU bound constants
100
100
  RELU_POT_BOUND = 8.0
101
101
 
102
- # Supported TP models names for Tensorflow:
103
- DEFAULT_TP_MODEL = 'default'
104
- IMX500_TP_MODEL = 'imx500'
105
- TFLITE_TP_MODEL = 'tflite'
106
- QNNPACK_TP_MODEL = 'qnnpack'
107
-
108
-
109
102
  # TFOpLambda functions:
110
103
  ADD = 'add'
111
104
  PAD = 'pad'
@@ -25,9 +25,9 @@ else:
25
25
  from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU
26
26
 
27
27
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
28
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
29
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
30
- from model_compression_toolkit.core.common.constants import SOFTMAX_THRESHOLD
28
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
29
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
30
+ from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
31
31
  from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
32
32
  KERNEL, DEPTHWISE_KERNEL
33
33
  from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
@@ -17,7 +17,7 @@
17
17
  from tensorflow.keras.layers import Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, Activation, SeparableConv2D
18
18
 
19
19
  from model_compression_toolkit.core import common
20
- from model_compression_toolkit.core.common.constants import FLOAT_32, DATA_TYPE
20
+ from model_compression_toolkit.constants import FLOAT_32, DATA_TYPE
21
21
  from model_compression_toolkit.core.common.graph.base_graph import Graph
22
22
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, \
23
23
  NodeFrameworkAttrMatcher
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
23
23
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, WalkMatcher
24
24
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
25
25
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
26
- from model_compression_toolkit.core.common.constants import THRESHOLD
26
+ from model_compression_toolkit.constants import THRESHOLD
27
27
  from model_compression_toolkit.core.keras.constants import KERNEL
28
28
 
29
29
  input_node = NodeOperationMatcher(InputLayer)
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
21
21
  from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
22
22
  from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \
23
23
  ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
24
- from model_compression_toolkit.core.common.logger import Logger
24
+ from model_compression_toolkit.logger import Logger
25
25
 
26
26
 
27
27
  def linear_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
@@ -23,17 +23,16 @@ else:
23
23
  from keras.layers.core import TFOpLambda
24
24
  from keras.layers import MultiHeadAttention, Conv2D, Softmax, Concatenate, Reshape, Permute
25
25
 
26
- from model_compression_toolkit.core.common.logger import Logger
26
+ from model_compression_toolkit.logger import Logger
27
27
  from model_compression_toolkit.core import common
28
28
  from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
29
29
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
30
30
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
31
- from model_compression_toolkit.core.common.constants import REUSE, REUSE_GROUP
32
- from model_compression_toolkit.core.keras.reader.node_builder import REUSED_IDENTIFIER
31
+ from model_compression_toolkit.constants import REUSE, REUSE_GROUP
33
32
  from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, NUM_HEADS, KEY_DIM, VALUE_DIM, \
34
33
  QUERY_SHAPE, KEY_SHAPE, VALUE_SHAPE, OUTPUT_SHAPE, ATTENTION_AXES, ACTIVATION, LINEAR, FILTERS, \
35
34
  FUNCTION, DIMS, TARGET_SHAPE, F_STRIDED_SLICE, F_STACK, Q_KERNEL, Q_BIAS, K_KERNEL, K_BIAS, V_KERNEL, V_BIAS, \
36
- OUTPUT_KERNEL, OUTPUT_BIAS, F_MATMUL, TRANSPOSE_B, KERNEL_SIZE, AXIS, F_STRIDED_SLICE_BEGIN, F_STRIDED_SLICE_END
35
+ OUTPUT_KERNEL, OUTPUT_BIAS, F_MATMUL, KERNEL_SIZE, AXIS, F_STRIDED_SLICE_BEGIN, F_STRIDED_SLICE_END
37
36
 
38
37
 
39
38
  class MHAParams:
@@ -23,6 +23,7 @@ from model_compression_toolkit.core import common
23
23
  from model_compression_toolkit.core.common import Graph, BaseNode
24
24
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
25
25
  from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, RELU_MAX_VALUE, RELU_POT_BOUND
26
+ from model_compression_toolkit.logger import Logger
26
27
 
27
28
 
28
29
  class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
@@ -81,7 +82,7 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
81
82
  scale_factor = max_value / self.threshold
82
83
 
83
84
  non_linear_node.framework_attr[RELU_MAX_VALUE] = np.float32(self.threshold)
84
- common.Logger.debug(
85
+ Logger.debug(
85
86
  f"Node named:{non_linear_node.name} max value change "
86
87
  f"to:{non_linear_node.framework_attr[RELU_MAX_VALUE]}")
87
88
 
@@ -20,7 +20,8 @@ from model_compression_toolkit.core import common
20
20
  from model_compression_toolkit.core.common import Graph, BaseNode
21
21
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher,NodeFrameworkAttrMatcher
22
22
  from model_compression_toolkit.core.keras.constants import RELU_MAX_VALUE
23
- from model_compression_toolkit.core.common.constants import THRESHOLD
23
+ from model_compression_toolkit.constants import THRESHOLD
24
+ from model_compression_toolkit.logger import Logger
24
25
 
25
26
  MATCHER = NodeOperationMatcher(ReLU) & NodeFrameworkAttrMatcher(RELU_MAX_VALUE, None).logic_not()
26
27
 
@@ -56,5 +57,5 @@ class RemoveReLUUpperBound(common.BaseSubstitution):
56
57
  node.final_activation_quantization_cfg.activation_quantization_params.get(THRESHOLD) == \
57
58
  node.framework_attr.get(RELU_MAX_VALUE):
58
59
  node.framework_attr[RELU_MAX_VALUE] = None
59
- common.Logger.info(f'Removing upper bound of {node.name}. Threshold and upper bound are equal.')
60
+ Logger.info(f'Removing upper bound of {node.name}. Threshold and upper bound are equal.')
60
61
  return graph
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
21
21
  NodeFrameworkAttrMatcher
22
22
  from model_compression_toolkit.core.common.substitutions.residual_collapsing import ResidualCollapsing
23
23
  from model_compression_toolkit.core.keras.constants import KERNEL, LINEAR, ACTIVATION, LAYER_NAME
24
- from model_compression_toolkit.core.common.logger import Logger
24
+ from model_compression_toolkit.logger import Logger
25
25
 
26
26
 
27
27
  def residual_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
@@ -22,7 +22,7 @@ from tensorflow.python.keras.layers.core import TFOpLambda
22
22
  from tensorflow.keras.layers import Activation, Conv2D, Dense, DepthwiseConv2D, ZeroPadding2D, Reshape, \
23
23
  GlobalAveragePooling2D, Dropout, ReLU, PReLU, ELU
24
24
 
25
- from model_compression_toolkit import CoreConfig, FrameworkInfo
25
+ from model_compression_toolkit.core import CoreConfig, FrameworkInfo
26
26
  from model_compression_toolkit.core.common import BaseNode, Graph
27
27
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
28
28
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, \
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import List, Any, Tuple, Callable, Type, Dict
15
+ from typing import List, Any, Tuple, Callable, Dict
16
16
 
17
17
  import numpy as np
18
18
  import tensorflow as tf
@@ -43,7 +43,7 @@ else:
43
43
  Concatenate, Add
44
44
  from keras.layers.core import TFOpLambda
45
45
 
46
- from model_compression_toolkit import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
46
+ from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
47
47
  from model_compression_toolkit.core import common
48
48
  from model_compression_toolkit.core.common import Graph, BaseNode
49
49
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -52,8 +52,6 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
52
52
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
53
53
  from model_compression_toolkit.core.common.user_info import UserInformation
54
54
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
55
- from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
56
- from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
57
55
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
58
56
  ActivationDecomposition
59
57
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.softmax_shift import \
@@ -348,12 +346,6 @@ class KerasImplementation(FrameworkImplementation):
348
346
  substitutions_list.append(keras_batchnorm_refusing())
349
347
  return substitutions_list
350
348
 
351
- def get_gptq_trainer_obj(self) -> Type[GPTQTrainer]:
352
- """
353
- Returns: Keras object of GPTQTrainer
354
- """
355
- return KerasGPTQTrainer
356
-
357
349
  def get_sensitivity_evaluator(self,
358
350
  graph: Graph,
359
351
  quant_config: MixedPrecisionQuantizationConfigV2,
@@ -1,6 +1,6 @@
1
1
  from tensorflow.keras.models import Model
2
2
 
3
- from model_compression_toolkit import FrameworkInfo
3
+ from model_compression_toolkit.core import FrameworkInfo
4
4
  from model_compression_toolkit.core.common.framework_info import ChannelAxis
5
5
  from model_compression_toolkit.core.common.model_validation import ModelValidation
6
6
  from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
@@ -8,7 +8,7 @@ if version.parse(tf.__version__) < version.parse("2.6"):
8
8
  else:
9
9
  from keras.layers import Activation, ReLU, BatchNormalization
10
10
 
11
- from model_compression_toolkit import FrameworkInfo
11
+ from model_compression_toolkit.core import FrameworkInfo
12
12
  from model_compression_toolkit.core.common import BaseNode
13
13
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
14
14
  from model_compression_toolkit.core.keras.constants import ACTIVATION, RELU_MAX_VALUE, NEGATIVE_SLOPE, THRESHOLD, \
@@ -15,19 +15,19 @@
15
15
 
16
16
  from typing import Callable
17
17
 
18
- from model_compression_toolkit import MixedPrecisionQuantizationConfig, CoreConfig, MixedPrecisionQuantizationConfigV2
18
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig, MixedPrecisionQuantizationConfigV2
19
19
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
20
- from model_compression_toolkit.core.common import Logger
21
- from model_compression_toolkit.core.common.constants import TENSORFLOW
22
- from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.constants import TENSORFLOW
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
23
23
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
24
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
25
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
26
26
  DEFAULT_MIXEDPRECISION_CONFIG
27
- from model_compression_toolkit.core.common.constants import FOUND_TF
27
+ from model_compression_toolkit.constants import FOUND_TF
28
28
 
29
29
  if FOUND_TF:
30
- from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
30
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
31
31
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
32
32
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
33
33
  from tensorflow.keras.models import Model
@@ -51,7 +51,7 @@ if FOUND_TF:
51
51
  representative_data_gen (Callable): Dataset used for calibration.
52
52
  quant_config (MixedPrecisionQuantizationConfig): MixedPrecisionQuantizationConfig containing parameters of how the model should be quantized.
53
53
  fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
54
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. `Default Keras TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/keras_tp_models/keras_default.py>`_
54
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
55
55
 
56
56
  Returns:
57
57
  A KPI object with total weights parameters sum, max activation tensor and total kpi.
@@ -75,7 +75,7 @@ if FOUND_TF:
75
75
  Import MCT and call for KPI data calculation:
76
76
 
77
77
  >>> import model_compression_toolkit as mct
78
- >>> kpi_data = mct.keras_kpi_data(model, repr_datagen)
78
+ >>> kpi_data = mct.core.keras_kpi_data(model, repr_datagen)
79
79
 
80
80
 
81
81
  """
@@ -112,7 +112,7 @@ if FOUND_TF:
112
112
  representative_data_gen (Callable): Dataset used for calibration.
113
113
  core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
114
114
  fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
115
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. `Default Keras TPC <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/tpc_models/keras_tp_models/keras_default.py>`_
115
+ target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
116
116
 
117
117
  Returns:
118
118
 
@@ -133,7 +133,7 @@ if FOUND_TF:
133
133
  Import MCT and call for KPI data calculation:
134
134
 
135
135
  >>> import model_compression_toolkit as mct
136
- >>> kpi_data = mct.keras_kpi_data(model, repr_datagen)
136
+ >>> kpi_data = mct.core.keras_kpi_data(model, repr_datagen)
137
137
 
138
138
  """
139
139
 
@@ -20,8 +20,8 @@ import tensorflow as tf
20
20
  import numpy as np
21
21
  from tensorflow.python.util.object_identity import Reference as TFReference
22
22
 
23
- from model_compression_toolkit.core.common.logger import Logger
24
- from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
23
+ from model_compression_toolkit.logger import Logger
24
+ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
25
25
  from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import threshold_is_power_of_two
26
26
 
27
27
 
@@ -23,7 +23,7 @@ from tensorflow_model_optimization.python.core.quantization.keras.quantize_confi
23
23
 
24
24
 
25
25
  from model_compression_toolkit.core.common import BaseNode
26
- from model_compression_toolkit.core.common.constants import INPUT_BASE_NAME
26
+ from model_compression_toolkit.constants import INPUT_BASE_NAME
27
27
 
28
28
 
29
29
  class InputLayerWrapperTransform(InputLayerQuantize):
@@ -1,11 +1,11 @@
1
- from typing import Tuple, Dict, Any, Callable
1
+ from typing import Tuple, Dict, Callable
2
2
 
3
3
  import numpy as np
4
4
  import tensorflow as tf
5
5
  from keras.layers import Layer
6
6
  from tensorflow.python.util.object_identity import Reference as TFReference
7
7
 
8
- from model_compression_toolkit.core.common.constants import SIGNED, CLUSTER_CENTERS, EPS, \
8
+ from model_compression_toolkit.constants import SIGNED, CLUSTER_CENTERS, EPS, \
9
9
  MULTIPLIER_N_BITS, THRESHOLD
10
10
 
11
11
 
@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.quantization.candidate_node_quantizat
24
24
  from model_compression_toolkit.core.keras.quantizer.mixed_precision.selective_activation_quantizer import \
25
25
  SelectiveActivationQuantizer
26
26
  from packaging import version
27
- from model_compression_toolkit.core.common.logger import Logger
27
+ from model_compression_toolkit.logger import Logger
28
28
 
29
29
  if version.parse(tf.__version__) < version.parse("2.6"):
30
30
  from tensorflow.python.keras.layers import Layer # pragma: no cover
@@ -29,7 +29,7 @@ else:
29
29
  from keras.engine.functional import Functional
30
30
  from keras.engine.sequential import Sequential
31
31
 
32
- from model_compression_toolkit.core.common.logger import Logger
32
+ from model_compression_toolkit.logger import Logger
33
33
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
34
34
 
35
35
 
@@ -19,7 +19,7 @@ from tensorflow.keras.layers import BatchNormalization
19
19
  from tqdm import tqdm
20
20
 
21
21
  import model_compression_toolkit.core.keras.constants as keras_constants
22
- from model_compression_toolkit import CoreConfig
22
+ from model_compression_toolkit.core import CoreConfig
23
23
  from model_compression_toolkit.core import common
24
24
 
25
25
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common import Logger
16
+ from model_compression_toolkit.logger import Logger
17
17
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
18
18
  from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
19
19
  from model_compression_toolkit.core.pytorch.back2framework.mixed_precision_model_builder import \
@@ -17,7 +17,7 @@ from typing import List, Tuple
17
17
 
18
18
  import torch
19
19
 
20
- from model_compression_toolkit import FrameworkInfo
20
+ from model_compression_toolkit.core import FrameworkInfo
21
21
  from model_compression_toolkit.core import common
22
22
  from model_compression_toolkit.core.common import BaseNode
23
23
  from model_compression_toolkit.core.common.user_info import UserInformation
@@ -17,7 +17,7 @@ from typing import List, Any, Tuple
17
17
 
18
18
  import torch
19
19
 
20
- from model_compression_toolkit import FrameworkInfo
20
+ from model_compression_toolkit.core import FrameworkInfo
21
21
  from model_compression_toolkit.core import common
22
22
  from model_compression_toolkit.core.common import BaseNode
23
23
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -22,13 +22,14 @@ import numpy as np
22
22
 
23
23
  from model_compression_toolkit.core import common
24
24
  from model_compression_toolkit.core.common import BaseNode, Graph
25
- from model_compression_toolkit.core.common.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
25
+ from model_compression_toolkit.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
26
26
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
27
27
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
28
28
  from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
29
- from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
30
- from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
31
- from model_compression_toolkit.core.common.logger import Logger
29
+ from model_compression_toolkit.core.pytorch.constants import BUFFER
30
+ from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, BufferHolder
31
+ from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy, get_working_device
32
+ from model_compression_toolkit.logger import Logger
32
33
 
33
34
 
34
35
  def build_input_tensors_list(node: BaseNode,
@@ -133,7 +134,13 @@ class PytorchModelGradients(torch.nn.Module):
133
134
 
134
135
  for n in self.node_sort:
135
136
  if not isinstance(n, FunctionalNode):
136
- self.add_module(n.name, node_builder(n))
137
+ if n.type == BufferHolder:
138
+ self.add_module(n.name, node_builder(n))
139
+ self.get_submodule(n.name). \
140
+ register_buffer(n.name,
141
+ torch.Tensor(n.get_weights_by_keys(BUFFER)).to(get_working_device()))
142
+ else:
143
+ self.add_module(n.name, node_builder(n))
137
144
 
138
145
  def forward(self,
139
146
  *args: Any) -> Any:
@@ -289,9 +296,9 @@ def pytorch_iterative_approx_jacobian_trace(graph_float: common.Graph,
289
296
 
290
297
  # If the change to the mean Jacobian approximation is insignificant we stop the calculation
291
298
  if j > MIN_JACOBIANS_ITER:
292
- delta = torch.mean(torch.stack([jac_trace_approx, *trace_jv])) - torch.mean(
293
- torch.stack(trace_jv))
294
- if torch.abs(delta) / (torch.abs(torch.mean(torch.stack(trace_jv))) + 1e-6) < JACOBIANS_COMP_TOLERANCE:
299
+ new_mean = torch.mean(torch.stack([jac_trace_approx, *trace_jv]))
300
+ delta = new_mean - torch.mean(torch.stack(trace_jv))
301
+ if torch.abs(delta) / (torch.abs(new_mean) + 1e-6) < JACOBIANS_COMP_TOLERANCE:
295
302
  trace_jv.append(jac_trace_approx)
296
303
  break
297
304
 
@@ -18,7 +18,7 @@ from typing import Tuple, Any, Dict, List, Union, Callable
18
18
  import torch
19
19
  from networkx import topological_sort
20
20
 
21
- from model_compression_toolkit import FrameworkInfo
21
+ from model_compression_toolkit.core import FrameworkInfo
22
22
  from model_compression_toolkit.core import common
23
23
  from model_compression_toolkit.core.common import BaseNode, Graph
24
24
  from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
@@ -17,7 +17,7 @@ from typing import List, Tuple
17
17
 
18
18
  import torch
19
19
 
20
- from model_compression_toolkit import FrameworkInfo
20
+ from model_compression_toolkit.core import FrameworkInfo
21
21
  from model_compression_toolkit.core import common
22
22
  from model_compression_toolkit.core.common import BaseNode
23
23
  from model_compression_toolkit.core.common.user_info import UserInformation
@@ -69,12 +69,6 @@ CPU = 'cpu'
69
69
  # ReLU bound constants
70
70
  RELU_POT_BOUND = 8.0
71
71
 
72
- # Supported TP models names for Pytorch:
73
- DEFAULT_TP_MODEL = 'default'
74
- IMX500_TP_MODEL = 'imx500'
75
- TFLITE_TP_MODEL = 'tflite'
76
- QNNPACK_TP_MODEL = 'qnnpack'
77
-
78
72
  # MultiHeadAttention layer attributes:
79
73
  EMBED_DIM = 'embed_dim'
80
74
  NUM_HEADS = 'num_heads'
@@ -92,3 +86,7 @@ IN_PROJ_WEIGHT = 'in_proj_weight'
92
86
  IN_PROJ_BIAS = 'in_proj_bias'
93
87
  BIAS_K = 'bias_k'
94
88
  BIAS_V = 'bias_v'
89
+
90
+ # # Batch size value for 'reshape' and 'view' operators,
91
+ # # the value is -1 so the batch size is inferred from the length of the array and remaining dimensions.
92
+ BATCH_DIM_VALUE = -1
@@ -19,8 +19,8 @@ from torch import sigmoid
19
19
 
20
20
  from model_compression_toolkit.core.common.defaultdict import DefaultDict
21
21
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
23
- from model_compression_toolkit.core.common.constants import SOFTMAX_THRESHOLD
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
+ from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
24
24
  from model_compression_toolkit.core.pytorch.constants import KERNEL
25
25
  from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import power_of_two_quantization, \
26
26
  symmetric_quantization, uniform_quantization
@@ -22,7 +22,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
22
22
  from model_compression_toolkit.core.common import BaseNode
23
23
  from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
24
24
  from model_compression_toolkit.core.pytorch.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, BIAS, USE_BIAS, FILTERS, PADDING, GROUPS
25
- from model_compression_toolkit.core.common.logger import Logger
25
+ from model_compression_toolkit.logger import Logger
26
26
 
27
27
 
28
28
  def linear_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
@@ -20,7 +20,7 @@ import torch.nn as nn
20
20
  import operator
21
21
  from typing import List
22
22
 
23
- from model_compression_toolkit.core.common.logger import Logger
23
+ from model_compression_toolkit.logger import Logger
24
24
  from model_compression_toolkit.core import common
25
25
  from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
26
26
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -25,6 +25,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
25
25
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
26
26
  from model_compression_toolkit.core.pytorch.constants import KERNEL, BIAS, INPLACE, HARDTANH_MIN_VAL, HARDTANH_MAX_VAL, \
27
27
  RELU_POT_BOUND
28
+ from model_compression_toolkit.logger import Logger
28
29
 
29
30
 
30
31
  class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
@@ -102,8 +103,8 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
102
103
  else:
103
104
  return graph
104
105
  else:
105
- common.Logger.error(f"In substitution with wrong matched pattern")
106
- common.Logger.debug(
106
+ Logger.error(f"In substitution with wrong matched pattern")
107
+ Logger.debug(
107
108
  f"Node named:{non_linear_node.name} changed "
108
109
  f"to:{non_linear_node.type}")
109
110
 
@@ -14,10 +14,13 @@
14
14
  # ==============================================================================
15
15
  from torch import reshape
16
16
  import torch
17
+
18
+ from model_compression_toolkit.logger import Logger
17
19
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
18
20
  from model_compression_toolkit.core import common
19
21
  from model_compression_toolkit.core.common.graph.base_graph import Graph
20
22
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
23
+ from model_compression_toolkit.core.pytorch.constants import BATCH_DIM_VALUE
21
24
 
22
25
 
23
26
  class ReshapeWithStaticShapes(common.BaseSubstitution):
@@ -47,14 +50,25 @@ class ReshapeWithStaticShapes(common.BaseSubstitution):
47
50
  Returns:
48
51
  Graph after applying the substitution.
49
52
  """
53
+ # we want the batch size value to infer from the length of the array and remaining dimensions
54
+ if len(node.output_shape) == 1:
55
+ node.output_shape[0][0] = BATCH_DIM_VALUE
56
+ else:
57
+ Logger.error('Reshape or view nodes should have a single output shape') # pragma: no cover
58
+
50
59
  # configure the new static output shape attribute
51
60
  node.op_call_args = node.output_shape
52
61
 
53
62
  # modify the node input info
54
63
  node.input_shape = [node.input_shape[0]]
64
+
65
+ # the first input is the tensor to be reshaped, we want his batch size value to infer
66
+ # from the length of the array and remaining dimensions
67
+ node.input_shape[0][0] = BATCH_DIM_VALUE
68
+
55
69
  nodes_to_check = []
56
70
  for in_edge in graph.incoming_edges(node):
57
- if in_edge.sink_index > 0: # the first input is the tensor to be reshaped
71
+ if in_edge.sink_index > 0: # the first input is the tensor to be reshaped
58
72
  nodes_to_check.append(in_edge.source_node)
59
73
  graph.remove_edge(in_edge.source_node, node)
60
74
  for n in nodes_to_check:
@@ -80,4 +94,4 @@ def clean_graph_from_nodes_without_out_edges(graph: Graph,
80
94
  graph.remove_edge(in_edge.source_node, node)
81
95
  graph.remove_node(node)
82
96
  for n in nodes_to_check:
83
- clean_graph_from_nodes_without_out_edges(graph, n)
97
+ clean_graph_from_nodes_without_out_edges(graph, n)