mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -12,18 +12,18 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Dict
16
-
17
15
  import numpy as np
18
16
  import torch
19
17
  import torch.nn as nn
20
18
  from torch import Tensor
21
19
 
22
- from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
23
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
24
- from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
25
- from model_compression_toolkit.core.common import constants as C
26
- from model_compression_toolkit import quantizers_infrastructure as qi, TrainingMethod
20
+ from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
21
+ from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
22
+
23
+ from model_compression_toolkit.qat import TrainingMethod
24
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
25
+ from model_compression_toolkit import quantizers_infrastructure as qi, constants as C
26
+
27
27
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
28
28
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
29
29
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
@@ -32,6 +32,7 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
32
32
  WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
33
33
  from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
34
34
  TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
35
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
35
36
 
36
37
 
37
38
  @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
@@ -64,22 +65,18 @@ class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
64
65
  [-1]) if self.quantization_config.weights_per_channel_threshold else float(
65
66
  self.min_values)
66
67
 
67
- self.quantizer_parameters = {}
68
68
 
69
69
  def initialize_quantization(self,
70
70
  tensor_shape: torch.Size,
71
71
  name: str,
72
- layer: qi.PytorchQuantizationWrapper) -> Dict[str, nn.Parameter]:
72
+ layer: qi.PytorchQuantizationWrapper):
73
73
  """
74
- Add min and max variables to layer.
75
- Args:
76
- tensor_shape: Tensor shape the quantizer quantize.
77
- name: Prefix of variables names.
78
- layer: Layer to add the variables to. The variables are saved
79
- in the layer's scope.
74
+ Add quantizer parameters to the quantizer parameters dictionary
80
75
 
81
- Returns:
82
- Dictionary of new variables.
76
+ Args:
77
+ tensor_shape: tensor shape of the quantized tensor.
78
+ name: Tensor name.
79
+ layer: Layer to quantize.
83
80
  """
84
81
 
85
82
  # Add min and max variables to layer.
@@ -87,9 +84,9 @@ class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
87
84
  layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(to_torch_tensor(self.max_values), requires_grad=False))
88
85
 
89
86
  # Save the quantizer parameters for later calculations
90
- self.quantizer_parameters = {FQ_MIN: layer.get_parameter(name+"_"+FQ_MIN), FQ_MAX: layer.get_parameter(name+"_"+FQ_MAX)}
87
+ self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
88
+ self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
91
89
 
92
- return self.quantizer_parameters
93
90
 
94
91
  def __call__(self,
95
92
  inputs: nn.Parameter,
@@ -102,7 +99,7 @@ class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
102
99
  Returns:
103
100
  quantized tensor
104
101
  """
105
- return uniform_quantizer(inputs, self.quantizer_parameters[FQ_MIN], self.quantizer_parameters[FQ_MAX], self.num_bits)
102
+ return uniform_quantizer(inputs, self.get_quantizer_variable(FQ_MIN), self.get_quantizer_variable(FQ_MAX), self.num_bits)
106
103
 
107
104
  def convert2inferable(self) -> WeightsUniformInferableQuantizer:
108
105
  """
@@ -111,8 +108,8 @@ class STEUniformWeightQATQuantizer(BasePytorchQATTrainableQuantizer):
111
108
  Returns:
112
109
  A pytorch inferable quanizer object.
113
110
  """
114
- _min = self.quantizer_parameters[FQ_MIN].cpu().detach().numpy()
115
- _max = self.quantizer_parameters[FQ_MAX].cpu().detach().numpy()
111
+ _min = self.get_quantizer_variable(FQ_MIN).cpu().detach().numpy()
112
+ _max = self.get_quantizer_variable(FQ_MAX).cpu().detach().numpy()
116
113
 
117
114
  return WeightsUniformInferableQuantizer(num_bits=self.num_bits,
118
115
  min_range=_min, max_range=_max,
@@ -143,21 +140,25 @@ class STEUniformActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
143
140
  self.min_range_tensor = torch.Tensor([np_min_range])
144
141
  self.max_range_tensor = torch.Tensor([np_max_range])
145
142
  self.num_bits = quantization_config.activation_n_bits
146
- self.quantizer_parameters = {}
147
143
 
148
144
  def initialize_quantization(self,
149
145
  tensor_shape: torch.Size,
150
146
  name: str,
151
- layer: qi.PytorchQuantizationWrapper) -> Dict[str, nn.Parameter]:
147
+ layer: qi.PytorchQuantizationWrapper):
152
148
  """
153
- Add min and max variables to layer.
149
+ Add quantizer parameters to the quantizer parameters dictionary
150
+
151
+ Args:
152
+ tensor_shape: tensor shape of the quantized tensor.
153
+ name: Tensor name.
154
+ layer: Layer to quantize.
154
155
  """
155
156
  layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(to_torch_tensor(self.min_range_tensor), requires_grad=True))
156
157
  layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(to_torch_tensor(self.max_range_tensor), requires_grad=True))
157
158
 
158
159
  # Save the quantizer parameters for later calculations
159
- self.quantizer_parameters = {FQ_MIN: layer.get_parameter(name+"_"+FQ_MIN), FQ_MAX: layer.get_parameter(name+"_"+FQ_MAX)}
160
- return self.quantizer_parameters
160
+ self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
161
+ self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
161
162
 
162
163
  def __call__(self,
163
164
  inputs: torch.Tensor,
@@ -172,8 +173,8 @@ class STEUniformActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
172
173
  The quantized tensor.
173
174
  """
174
175
 
175
- _min = self.quantizer_parameters[FQ_MIN]
176
- _max = self.quantizer_parameters[FQ_MAX]
176
+ _min = self.get_quantizer_variable(FQ_MIN)
177
+ _max = self.get_quantizer_variable(FQ_MAX)
177
178
  q_tensor = uniform_quantizer(inputs, _min, _max, self.num_bits)
178
179
  return q_tensor
179
180
 
@@ -184,8 +185,8 @@ class STEUniformActivationQATQuantizer(BasePytorchQATTrainableQuantizer):
184
185
  Returns:
185
186
  A pytorch inferable quanizer object.
186
187
  """
187
- _min = self.quantizer_parameters[FQ_MIN].cpu().detach().numpy()
188
- _max = self.quantizer_parameters[FQ_MAX].cpu().detach().numpy()
188
+ _min = self.get_quantizer_variable(FQ_MIN).cpu().detach().numpy()
189
+ _max = self.get_quantizer_variable(FQ_MAX).cpu().detach().numpy()
189
190
 
190
191
  return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
191
192
  min_range=_min, max_range=_max)
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ # Quantizers constants (for GPTQ, QAT, etc.)
16
17
  FQ_MIN = "min"
17
18
  FQ_MAX = "max"
18
19
  THRESHOLD_TENSOR = "ptq_threshold_tensor"
@@ -15,7 +15,7 @@
15
15
  from enum import Enum
16
16
  from typing import Any, Dict, List
17
17
 
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
 
20
20
 
21
21
  class QuantizationTarget(Enum):
@@ -28,5 +28,4 @@ def get_all_subclasses(cls: type) -> Set[type]:
28
28
 
29
29
  """
30
30
 
31
- return set(cls.__subclasses__()).union(
32
- [s for c in cls.__subclasses__() for s in get_all_subclasses(c)])
31
+ return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in get_all_subclasses(c)])
@@ -13,8 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common import Logger
17
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
16
+ from model_compression_toolkit.logger import Logger
17
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
18
18
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import QUANTIZATION_TARGET, \
20
20
  QUANTIZATION_METHOD
@@ -41,7 +41,7 @@ def get_inferable_quantizer_class(quant_target: QuantizationTarget,
41
41
  qat_quantizer_classes = get_all_subclasses(quantizer_base_class)
42
42
  filtered_quantizers = list(filter(lambda q_class: getattr(q_class, QUANTIZATION_TARGET) == quant_target and
43
43
  getattr(q_class, QUANTIZATION_METHOD) is not None and
44
- quant_method in getattr(q_class, QUANTIZATION_METHOD),
44
+ quant_method in getattr(q_class, QUANTIZATION_METHOD),
45
45
  qat_quantizer_classes))
46
46
 
47
47
  if len(filtered_quantizers) != 1:
@@ -12,8 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.core.common import Logger
16
- from model_compression_toolkit.core.common.constants import FOUND_TF
15
+ from model_compression_toolkit.logger import Logger
16
+ from model_compression_toolkit.constants import FOUND_TF
17
17
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses import get_all_subclasses
18
18
 
19
19
  if FOUND_TF:
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  from typing import Dict, List, Any, Tuple
16
16
  from model_compression_toolkit import quantizers_infrastructure as qi
17
- from model_compression_toolkit.core.common.constants import FOUND_TF
18
- from model_compression_toolkit.core.common.logger import Logger
17
+ from model_compression_toolkit.constants import FOUND_TF
18
+ from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import WEIGHTS_QUANTIZERS, ACTIVATION_QUANTIZERS, LAYER, STEPS, TRAINING
21
21
 
@@ -18,6 +18,10 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
18
18
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_pot_inferable_quantizer import WeightsPOTInferableQuantizer
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_symmetric_inferable_quantizer import WeightsSymmetricInferableQuantizer
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_uniform_inferable_quantizer import WeightsUniformInferableQuantizer
21
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_lut_symmetric_inferable_quantizer import WeightsLUTSymmetricInferableQuantizer
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.weights_inferable_quantizers.weights_lut_pot_inferable_quantizer import WeightsLUTPOTInferableQuantizer
23
+
21
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_pot_inferable_quantizer import ActivationPOTInferableQuantizer
22
25
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_symmetric_inferable_quantizer import ActivationSymmetricInferableQuantizer
23
26
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_uniform_inferable_quantizer import ActivationUniformInferableQuantizer
27
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.activation_inferable_quantizers.activation_lut_pot_inferable_quantizer import ActivationLutPOTInferableQuantizer
@@ -17,10 +17,10 @@ from typing import List
17
17
 
18
18
  import numpy as np
19
19
 
20
- from model_compression_toolkit.core.common.logger import Logger
21
- from model_compression_toolkit.core.common.constants import FOUND_TF
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.constants import FOUND_TF
22
22
 
23
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
23
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
24
24
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
25
25
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
26
26
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import MULTIPLIER_N_BITS, EPS
@@ -16,10 +16,10 @@ from typing import List
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.logger import Logger
20
- from model_compression_toolkit.core.common.constants import FOUND_TF
19
+ from model_compression_toolkit.logger import Logger
20
+ from model_compression_toolkit.constants import FOUND_TF
21
21
 
22
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
23
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
24
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
25
25
 
@@ -16,9 +16,9 @@ from typing import List
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
19
+ from model_compression_toolkit.constants import FOUND_TF
20
20
 
21
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
23
23
  QuantizationTarget
24
24
 
@@ -16,9 +16,9 @@ from typing import List
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.logger import Logger
20
- from model_compression_toolkit.core.common.constants import FOUND_TF
21
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.logger import Logger
20
+ from model_compression_toolkit.constants import FOUND_TF
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
22
  from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
23
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
24
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.quant_utils import \
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from abc import abstractmethod
16
16
 
17
- from model_compression_toolkit.core.common.constants import FOUND_TF
17
+ from model_compression_toolkit.constants import FOUND_TF
18
18
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
19
19
 
20
20
  if FOUND_TF:
@@ -22,3 +22,4 @@ MIN_RANGE = 'min_range'
22
22
  MAX_RANGE = 'max_range'
23
23
  CHANNEL_AXIS = 'channel_axis'
24
24
  INPUT_RANK = 'input_rank'
25
+ CLUSTER_CENTERS = 'cluster_centers'
@@ -16,8 +16,8 @@ from typing import List
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.constants import FOUND_TF
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
22
22
  QuantizationTarget
23
23
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import MULTIPLIER_N_BITS, EPS
@@ -17,8 +17,8 @@ from typing import List
17
17
 
18
18
  import numpy as np
19
19
 
20
- from model_compression_toolkit.core.common.constants import FOUND_TF
21
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
20
+ from model_compression_toolkit.constants import FOUND_TF
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
23
23
  QuantizationTarget
24
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import MULTIPLIER_N_BITS, EPS
@@ -16,8 +16,8 @@ from typing import List
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.constants import FOUND_TF
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
22
22
 
23
23
  if FOUND_TF:
@@ -16,9 +16,9 @@ from typing import List
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
19
+ from model_compression_toolkit.constants import FOUND_TF
20
20
 
21
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
21
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
23
23
 
24
24
  if FOUND_TF:
@@ -16,8 +16,8 @@ from typing import List
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common.constants import FOUND_TF
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
19
+ from model_compression_toolkit.constants import FOUND_TF
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, QuantizationTarget
22
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.quant_utils import \
23
23
  adjust_range_to_include_zero
@@ -16,7 +16,7 @@ from typing import Any
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.core.common import Logger
19
+ from model_compression_toolkit.logger import Logger
20
20
 
21
21
 
22
22
  def validate_uniform_min_max_ranges(min_range: Any, max_range: Any) -> None:
@@ -13,8 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================f
15
15
  from typing import List, Union, Any, Dict, Tuple
16
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
17
- from model_compression_toolkit.core.common.logger import Logger
16
+ from model_compression_toolkit.constants import FOUND_TORCH
17
+ from model_compression_toolkit.logger import Logger
18
18
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER, TRAINING
20
20
  import inspect
@@ -184,13 +184,11 @@ if FOUND_TORCH:
184
184
  return self._weights_vars
185
185
 
186
186
  def forward(self,
187
- x: torch.Tensor,
188
187
  *args: List[Any],
189
188
  **kwargs: Dict[str, Any]) -> Union[torch.Tensor, List[torch.Tensor]]:
190
189
  """
191
190
  PytorchQuantizationWrapper forward functions
192
191
  Args:
193
- x: layer's inputs
194
192
  args: arguments to pass to internal layer.
195
193
  kwargs: key-word dictionary to pass to the internal layer.
196
194
 
@@ -218,7 +216,7 @@ if FOUND_TORCH:
218
216
  # ----------------------------------
219
217
  # Layer operation
220
218
  # ----------------------------------
221
- outputs = self.layer(x, *args, **kwargs)
219
+ outputs = self.layer(*args, **kwargs)
222
220
 
223
221
  # ----------------------------------
224
222
  # Quantize all activations
@@ -240,6 +238,18 @@ if FOUND_TORCH:
240
238
 
241
239
  return outputs
242
240
 
241
+ def get_quantized_weights(self) -> Dict[str, torch.Tensor]:
242
+ """
243
+
244
+ Returns: A dictionary of weights attributes to quantized weights.
245
+
246
+ """
247
+ quantized_weights = {}
248
+ weights_var = self.get_weights_vars()
249
+ for name, w, quantizer in weights_var:
250
+ quantized_weights[name] = quantizer(w)
251
+ return quantized_weights
252
+
243
253
  else:
244
254
  class PytorchQuantizationWrapper(object):
245
255
  def __init__(self,
@@ -19,6 +19,8 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
19
19
  import ActivationSymmetricInferableQuantizer
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.activation_inferable_quantizers.activation_uniform_inferable_quantizer \
21
21
  import ActivationUniformInferableQuantizer
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.activation_inferable_quantizers.activation_lut_pot_inferable_quantizer \
23
+ import ActivationLutPOTInferableQuantizer
22
24
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_pytorch_inferable_quantizer \
23
25
  import BasePyTorchInferableQuantizer
24
26
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_pot_inferable_quantizer \
@@ -27,3 +29,7 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
27
29
  import WeightsSymmetricInferableQuantizer
28
30
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_uniform_inferable_quantizer \
29
31
  import WeightsUniformInferableQuantizer
32
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_lut_symmetric_inferable_quantizer \
33
+ import WeightsLUTSymmetricInferableQuantizer
34
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_lut_pot_inferable_quantizer \
35
+ import WeightsLUTPOTInferableQuantizer
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  import numpy as np
16
16
 
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
20
20
  import mark_quantizer, QuantizationTarget
21
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
@@ -15,8 +15,8 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
21
21
  QuantizationTarget
22
22
 
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  import numpy as np
16
16
 
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
20
20
  QuantizationTarget
21
21
 
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  import numpy as np
16
16
 
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
20
20
  QuantizationTarget
21
21
 
@@ -15,8 +15,8 @@
15
15
  import numpy as np
16
16
  import warnings
17
17
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
21
21
  import mark_quantizer
22
22
 
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from abc import abstractmethod
16
16
 
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
18
  from model_compression_toolkit.quantizers_infrastructure import BaseInferableQuantizer
19
19
 
20
20
  if FOUND_TORCH:
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  import numpy as np
16
16
 
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
20
20
 
21
21
  if FOUND_TORCH:
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  import numpy as np
16
16
 
17
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
18
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
19
19
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
20
20
 
21
21
  if FOUND_TORCH:
@@ -21,3 +21,6 @@ PER_CHANNEL = 'per_channel'
21
21
  MIN_RANGE = 'min_range'
22
22
  MAX_RANGE = 'max_range'
23
23
  CHANNEL_AXIS = 'channel_axis'
24
+ CLUSTER_CENTERS = 'cluster_centers'
25
+ MULTIPLIER_N_BITS = 'multiplier_n_bits'
26
+ EPS = 'eps'
@@ -15,8 +15,8 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
21
21
  import mark_quantizer, QuantizationTarget
22
22
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
@@ -15,8 +15,8 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
21
21
  import mark_quantizer, \
22
22
  QuantizationTarget
@@ -15,8 +15,8 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
21
21
  QuantizationTarget
22
22
 
@@ -15,8 +15,8 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
20
20
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
21
21
  QuantizationTarget
22
22
 
@@ -15,9 +15,9 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from model_compression_toolkit.core.common.constants import FOUND_TORCH
19
- from model_compression_toolkit.core.common.logger import Logger
20
- from model_compression_toolkit.core.common.target_platform import QuantizationMethod
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.logger import Logger
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
21
  from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget, \
22
22
  mark_quantizer
23
23