mct-nightly 1.8.0.22032023.post333__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +4 -3
  2. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +294 -284
  3. model_compression_toolkit/__init__.py +9 -32
  4. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  5. model_compression_toolkit/core/__init__.py +14 -0
  6. model_compression_toolkit/core/analyzer.py +3 -2
  7. model_compression_toolkit/core/common/__init__.py +0 -1
  8. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  9. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  11. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  12. model_compression_toolkit/core/common/framework_info.py +1 -1
  13. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  14. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  15. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  16. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  18. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  19. model_compression_toolkit/core/common/memory_computation.py +1 -1
  20. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  21. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  23. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  26. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  28. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  29. model_compression_toolkit/core/common/model_collector.py +2 -2
  30. model_compression_toolkit/core/common/model_validation.py +1 -1
  31. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  32. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  33. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  34. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  35. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  36. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  37. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  38. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  39. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  50. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  51. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  52. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  54. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  55. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  56. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  57. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  58. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  60. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  63. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  65. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  66. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  67. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  69. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  73. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  74. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  75. model_compression_toolkit/core/keras/constants.py +0 -7
  76. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  85. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  86. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  87. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  88. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  89. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  90. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  91. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  92. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  93. model_compression_toolkit/core/keras/reader/common.py +1 -1
  94. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  95. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  99. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  100. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/constants.py +0 -6
  102. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  103. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  109. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  110. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  111. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  112. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  113. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  114. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  115. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  116. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  117. model_compression_toolkit/core/runner.py +7 -7
  118. model_compression_toolkit/exporter/__init__.py +6 -3
  119. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  120. model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
  121. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
  124. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
  125. model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
  126. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  127. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  128. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
  129. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
  130. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
  131. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
  132. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
  133. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
  135. model_compression_toolkit/gptq/common/gptq_config.py +2 -4
  136. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  137. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  138. model_compression_toolkit/gptq/common/gptq_training.py +5 -4
  139. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  140. model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
  141. model_compression_toolkit/gptq/keras/graph_info.py +4 -0
  142. model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
  143. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  144. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
  145. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  146. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
  147. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +21 -16
  148. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  149. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
  150. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  151. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  152. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
  153. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  154. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
  155. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  156. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
  157. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +13 -5
  158. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  159. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
  160. model_compression_toolkit/gptq/runner.py +3 -2
  161. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
  162. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  163. model_compression_toolkit/ptq/__init__.py +3 -0
  164. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  165. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  166. model_compression_toolkit/qat/__init__.py +4 -0
  167. model_compression_toolkit/qat/common/__init__.py +1 -2
  168. model_compression_toolkit/qat/common/qat_config.py +5 -1
  169. model_compression_toolkit/qat/keras/quantization_facade.py +34 -28
  170. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  171. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
  172. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
  173. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
  174. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  175. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  176. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
  177. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
  178. model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
  179. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  180. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  181. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
  182. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
  183. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
  184. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
  185. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  186. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  187. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +3 -5
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -3
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  211. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
  212. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  213. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
  214. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  215. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  217. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
  218. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  219. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  220. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  221. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  222. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  223. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  224. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  225. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  226. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  227. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  228. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  229. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  232. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  233. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  234. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  235. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  236. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  237. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  238. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  239. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  240. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  241. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  242. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  243. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  244. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  248. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  250. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  254. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  255. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  257. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  259. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  261. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  263. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  265. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  273. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  274. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  275. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  276. model_compression_toolkit/exporter/model_exporter/tflite/__init__.py +0 -14
  277. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
  278. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
  279. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
  280. {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
  281. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  282. /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
  283. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  284. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  285. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  286. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  287. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  288. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  289. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  290. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  291. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  292. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  293. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  294. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py
@@ -0,0 +1,147 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from model_compression_toolkit.constants import FOUND_TF
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer, BaseKerasTrainableQuantizer
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import \
+     ACTIVATION_HOLDER_QUANTIZER, TRAINING, STEPS
+
+ if FOUND_TF:
+     import tensorflow as tf
+     from keras.utils import tf_inspect
+     from tensorflow_model_optimization.python.core.keras import utils
+
+     def _make_quantizer_fn(quantizer, x, training):
+         """Use currying to return True/False specialized fns to the cond."""
+
+         def quantizer_fn():
+             return quantizer(x, training)
+
+         return quantizer_fn
+
+     keras = tf.keras
+
+     class ActivationQuantizationHolder(keras.layers.Layer):
+         """
+         Keras layer to hold an activation quantizer and quantize during inference.
+         """
+         def __init__(self,
+                      activation_holder_quantizer: BaseInferableQuantizer,
+                      **kwargs):
+             """
+
+             Args:
+                 activation_holder_quantizer: Quantizer to use during inference.
+                 **kwargs: Key-word arguments for the base layer
+             """
+
+             super(ActivationQuantizationHolder, self).__init__(**kwargs)
+             self.activation_holder_quantizer = activation_holder_quantizer
+
+         def get_config(self):
+             """
+             Returns: Configuration of ActivationQuantizationHolder.
+
+             """
+             base_config = super(ActivationQuantizationHolder, self).get_config()
+             config = {
+                 ACTIVATION_HOLDER_QUANTIZER: keras.utils.serialize_keras_object(self.activation_holder_quantizer)}
+
+             return dict(list(base_config.items()) + list(config.items()))
+
+         @classmethod
+         def from_config(cls, config):
+             """
+
+             Args:
+                 config(dict): dictionary of ActivationQuantizationHolder Configuration
+
+             Returns: A ActivationQuantizationHolder object
+
+             """
+             config = config.copy()
+             activation_holder_quantizer = keras.utils.deserialize_keras_object(config.pop(ACTIVATION_HOLDER_QUANTIZER),
+                                                                                module_objects=globals(),
+                                                                                custom_objects=None)
+
+             return cls(activation_holder_quantizer=activation_holder_quantizer,
+                        **config)
+
+         def build(self, input_shape):
+             """
+             ActivationQuantizationHolder build function.
+             Args:
+                 input_shape: the layer input shape
+
+             Returns: None
+
+             """
+             super(ActivationQuantizationHolder, self).build(input_shape)
+
+             self.optimizer_step = self.add_weight(
+                 STEPS,
+                 initializer=tf.keras.initializers.Constant(-1),
+                 dtype=tf.dtypes.int32,
+                 trainable=False)
+
+             self.activation_holder_quantizer.initialize_quantization(None,
+                                                                      self.name + '/out_',
+                                                                      self)
+
+         def call(self,
+                  inputs: tf.Tensor,
+                  training=None) -> tf.Tensor:
+             """
+             Quantizes the input tensor using the activation quantizer the ActivationQuantizationHolder holds.
+
+             Args:
+                 inputs: Input tensors to quantize use the activation quantizer the object holds
+                 training: a boolean stating if layer is in training mode.
+
+             Returns: Output of the activation quantizer (quantized input tensor).
+
+             """
+             if training is None:
+                 training = tf.keras.backend.learning_phase()
+
+             activation_quantizer_args_spec = tf_inspect.getfullargspec(self.activation_holder_quantizer.__call__).args
+             if TRAINING in activation_quantizer_args_spec:
+                 return utils.smart_cond(
+                     training,
+                     _make_quantizer_fn(self.activation_holder_quantizer, inputs, True),
+                     _make_quantizer_fn(self.activation_holder_quantizer, inputs, False))
+
+             return self.activation_holder_quantizer(inputs)
+
+         def convert_to_inferable_quantizers(self):
+             """
+             Convert layer's quantizer to inferable quantizer.
+
+             Returns:
+                 None
+             """
+             if isinstance(self.activation_holder_quantizer, BaseKerasTrainableQuantizer):
+                 self.activation_holder_quantizer = self.activation_holder_quantizer.convert2inferable()
+
+
+
+
+ else:
+     class ActivationQuantizationHolder:  # pragma: no cover
+         def __init__(self, *args, **kwargs):
+             Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                          'when using ActivationQuantizationHolder. '
+                          'Could not find Tensorflow package.')
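For orientation, a minimal usage sketch of the new holder layer (a hedged example, not part of the diff: it assumes TensorFlow is installed, and `my_inferable_quantizer` is a placeholder for any BaseInferableQuantizer instance from the inferable infrastructure):

import tensorflow as tf
from model_compression_toolkit import quantizers_infrastructure as qi

# my_inferable_quantizer is a hypothetical placeholder for an activation inferable quantizer instance.
holder = qi.ActivationQuantizationHolder(activation_holder_quantizer=my_inferable_quantizer)

inputs = tf.keras.Input(shape=(16,))
x = tf.keras.layers.Dense(8)(inputs)
outputs = holder(x)  # the held quantizer fake-quantizes the activations on the forward pass
model = tf.keras.Model(inputs, outputs)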
model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py
@@ -12,17 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from model_compression_toolkit.core.common import Logger
- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TF
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses import get_all_subclasses

  if FOUND_TF:
      import tensorflow as tf
      from model_compression_toolkit import quantizers_infrastructure as qi
      from model_compression_toolkit.quantizers_infrastructure import BaseKerasTrainableQuantizer
-     from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer \
-         import \
-         BaseKerasInferableQuantizer
+     from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
      keras = tf.keras

      def keras_load_quantized_model(filepath, custom_objects=None, compile=True, options=None):
@@ -57,6 +55,8 @@ if FOUND_TF:

          # Add non-quantizers custom objects
          qi_custom_objects.update({qi.KerasQuantizationWrapper.__name__: qi.KerasQuantizationWrapper})
+         qi_custom_objects.update({qi.ActivationQuantizationHolder.__name__: qi.ActivationQuantizationHolder})
+
          if custom_objects is not None:
              qi_custom_objects.update(custom_objects)
          return tf.keras.models.load_model(filepath,
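A hedged loading sketch (keras_load_quantized_model and its signature come from the hunk above, and the module path is the file shown in this diff; the file name passed in is illustrative):

from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import \
    keras_load_quantized_model

# Quantization custom objects (KerasQuantizationWrapper, ActivationQuantizationHolder, quantizer classes)
# are registered by the loader itself, so they do not need to be passed explicitly.
model = keras_load_quantized_model('exported_quantized_model.h5')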
model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py
@@ -14,8 +14,8 @@
  # ==============================================================================
  from typing import Dict, List, Any, Tuple
  from model_compression_toolkit import quantizers_infrastructure as qi
- from model_compression_toolkit.core.common.constants import FOUND_TF
- from model_compression_toolkit.core.common.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TF
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import WEIGHTS_QUANTIZERS, ACTIVATION_QUANTIZERS, LAYER, STEPS, TRAINING


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py
@@ -17,10 +17,10 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.logger import Logger
- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TF

- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import MULTIPLIER_N_BITS, EPS

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py
@@ -16,10 +16,10 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.logger import Logger
- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TF

- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py
@@ -16,9 +16,9 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.constants import FOUND_TF

- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py
@@ -16,9 +16,9 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.logger import Logger
- from model_compression_toolkit.core.common.constants import FOUND_TF
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TF
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.quant_utils import \
model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py
@@ -14,7 +14,7 @@
  # ==============================================================================
  from abc import abstractmethod

- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.constants import FOUND_TF
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer

  if FOUND_TF:

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py
@@ -16,8 +16,8 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TF
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TF
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import MULTIPLIER_N_BITS, EPS

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py
@@ -17,8 +17,8 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TF
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TF
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import MULTIPLIER_N_BITS, EPS

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py
@@ -16,8 +16,8 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TF
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TF
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget

  if FOUND_TF:

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py
@@ -16,9 +16,9 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TF
+ from model_compression_toolkit.constants import FOUND_TF

- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget

  if FOUND_TF:

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py
@@ -16,8 +16,8 @@ from typing import List

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TF
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TF
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.quant_utils import \
      adjust_range_to_include_zero

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py
@@ -16,7 +16,7 @@ from typing import Any

  import numpy as np

- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger


  def validate_uniform_min_max_ranges(min_range: Any, max_range: Any) -> None:

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py
@@ -13,8 +13,8 @@
  # limitations under the License.
  # ==============================================================================f
  from typing import List, Union, Any, Dict, Tuple
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.logger import Logger
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER, TRAINING
  import inspect
@@ -184,13 +184,11 @@ if FOUND_TORCH:
              return self._weights_vars

          def forward(self,
-                     x: torch.Tensor,
                      *args: List[Any],
                      **kwargs: Dict[str, Any]) -> Union[torch.Tensor, List[torch.Tensor]]:
              """
              PytorchQuantizationWrapper forward functions
              Args:
-                 x: layer's inputs
                  args: arguments to pass to internal layer.
                  kwargs: key-word dictionary to pass to the internal layer.

@@ -218,7 +216,7 @@
              # ----------------------------------
              # Layer operation
              # ----------------------------------
-             outputs = self.layer(x, *args, **kwargs)
+             outputs = self.layer(*args, **kwargs)

              # ----------------------------------
              # Quantize all activations
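A brief sketch of what the signature change above means at the call site (hedged: `wrapped` stands for an already-constructed PytorchQuantizationWrapper instance, whose construction is omitted, and the tensor shape is illustrative):

import torch

# The wrapper's forward no longer takes a dedicated `x` argument; every positional
# input now flows through *args and is forwarded to the wrapped layer unchanged.
y = wrapped(torch.randn(1, 3, 32, 32))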
model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py
@@ -14,8 +14,8 @@
  # ==============================================================================
  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
      import mark_quantizer, QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
@@ -85,7 +85,6 @@ if FOUND_TORCH:
              Returns:
                  quantized tensor.
              """
-             inputs.requires_grad = False
              return lut_quantizer(inputs, cluster_centers=self.cluster_centers, signed=self.signed,
                                   threshold=self.threshold, multiplier_n_bits=self.multiplier_n_bits, eps=self.eps)


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py
@@ -15,8 +15,8 @@

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py
@@ -14,8 +14,8 @@
  # ==============================================================================
  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py
@@ -14,8 +14,8 @@
  # ==============================================================================
  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py
@@ -15,8 +15,8 @@
  import numpy as np
  import warnings

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
      import mark_quantizer


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py
@@ -14,7 +14,7 @@
  # ==============================================================================
  from abc import abstractmethod

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
+ from model_compression_toolkit.constants import FOUND_TORCH
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer

  if FOUND_TORCH:

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py
@@ -14,8 +14,8 @@
  # ==============================================================================
  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer

  if FOUND_TORCH:

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py
@@ -14,8 +14,8 @@
  # ==============================================================================
  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer

  if FOUND_TORCH:

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py
@@ -15,8 +15,8 @@

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
      import mark_quantizer, QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py
@@ -15,8 +15,8 @@

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
      import mark_quantizer, \
      QuantizationTarget

model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py
@@ -15,8 +15,8 @@

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py
@@ -15,8 +15,8 @@

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
      QuantizationTarget


model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py
@@ -15,9 +15,9 @@

  import numpy as np

- from model_compression_toolkit.core.common.constants import FOUND_TORCH
- from model_compression_toolkit.core.common.logger import Logger
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.constants import FOUND_TORCH
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget, \
      mark_quantizer

model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py
@@ -18,7 +18,7 @@ from typing import Union, List, Any
  from inspect import signature

  from model_compression_toolkit.core import common
- from model_compression_toolkit.core.common import Logger
+ from model_compression_toolkit.logger import Logger

  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer, \
      QuantizationTarget
@@ -55,9 +55,9 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
          for i, (k, v) in enumerate(self.get_sig().parameters.items()):
              if i == 0:
                  if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
-                     common.Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
+                     Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
              elif v.default is v.empty:
-                 common.Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover
+                 Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover

          super(BaseTrainableQuantizer, self).__init__()
          self.quantization_config = quantization_config
@@ -73,15 +73,15 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
          if static_quantization_target == QuantizationTarget.Weights:
              self.validate_weights()
              if self.quantization_config.weights_quantization_method not in static_quantization_method:
-                 common.Logger.error(
+                 Logger.error(
                      f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.weights_quantization_method}')
          elif static_quantization_target == QuantizationTarget.Activation:
              self.validate_activation()
              if self.quantization_config.activation_quantization_method not in static_quantization_method:
-                 common.Logger.error(
+                 Logger.error(
                      f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.activation_quantization_method}')
          else:
-             common.Logger.error(
+             Logger.error(
                  f'Unknown Quantization Part:{static_quantization_target}') # pragma: no cover

          self.quantizer_parameters = {}
@@ -145,7 +145,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):

          """
          if self.activation_quantization() or not self.weights_quantization():
-             common.Logger.error(f'Expect weight quantization got activation')
+             Logger.error(f'Expect weight quantization got activation')

      def validate_activation(self) -> None:
          """
@@ -153,7 +153,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):

          """
          if not self.activation_quantization() or self.weights_quantization():
-             common.Logger.error(f'Expect activation quantization got weight')
+             Logger.error(f'Expect activation quantization got weight')

      def convert2inferable(self) -> BaseInferableQuantizer:
          """
@@ -183,7 +183,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
          if name in self.quantizer_parameters:
              return self.quantizer_parameters[name][VAR]
          else:
-             common.Logger.error(f'Variable {name} is not exist in quantizers parameters!') # pragma: no cover
+             Logger.error(f'Variable {name} is not exist in quantizers parameters!') # pragma: no cover


      @abstractmethod

model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py
@@ -13,7 +13,8 @@
  # limitations under the License.
  # ==============================================================================
  from typing import List
- from model_compression_toolkit.core.common import BaseNode, Logger
+ from model_compression_toolkit.core.common import BaseNode
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
      TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig, TrainableQuantizerCandidateConfig


model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py
@@ -12,12 +12,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Union
+ from typing import Union, Any

- from model_compression_toolkit.gptq import RoundingType
- from model_compression_toolkit import TrainingMethod
- from model_compression_toolkit.core.common import Logger
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
      import QUANTIZATION_TARGET, QUANTIZATION_METHOD, QUANTIZER_TYPE
@@ -26,7 +24,7 @@ from model_compression_toolkit.quantizers_infrastructur


  def get_trainable_quantizer_class(quant_target: QuantizationTarget,
-                                   quantizer_type: Union[TrainingMethod, RoundingType],
+                                   quantizer_type: Union[Any, Any],
                                    quant_method: QuantizationMethod,
                                    quantizer_base_class: type) -> type:
      """