mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307) hide show
  1. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
  2. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
  3. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +12 -41
  5. model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
  6. model_compression_toolkit/core/__init__.py +14 -0
  7. model_compression_toolkit/core/analyzer.py +3 -2
  8. model_compression_toolkit/core/common/__init__.py +0 -1
  9. model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
  10. model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
  11. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
  12. model_compression_toolkit/core/common/framework_implementation.py +1 -8
  13. model_compression_toolkit/core/common/framework_info.py +1 -1
  14. model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
  15. model_compression_toolkit/core/common/graph/base_graph.py +2 -2
  16. model_compression_toolkit/core/common/graph/base_node.py +57 -1
  17. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
  18. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
  19. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  20. model_compression_toolkit/core/common/memory_computation.py +1 -1
  21. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
  22. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
  23. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
  24. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  25. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
  26. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  27. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
  28. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  29. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  30. model_compression_toolkit/core/common/model_collector.py +2 -2
  31. model_compression_toolkit/core/common/model_validation.py +1 -1
  32. model_compression_toolkit/core/common/network_editors/actions.py +4 -1
  33. model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
  34. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
  35. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
  36. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
  37. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
  38. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  39. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
  42. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
  43. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
  44. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
  45. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
  46. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
  47. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
  48. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
  50. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
  51. model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
  52. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
  53. model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
  54. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
  55. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
  56. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
  57. model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
  58. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  59. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
  60. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
  61. model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
  62. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  63. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
  64. model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
  65. model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
  66. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  67. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  68. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  69. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
  70. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
  71. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
  72. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
  73. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
  74. model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
  75. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
  76. model_compression_toolkit/core/keras/constants.py +0 -7
  77. model_compression_toolkit/core/keras/default_framework_info.py +3 -3
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
  80. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  81. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
  82. model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
  83. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
  84. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  85. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  86. model_compression_toolkit/core/keras/keras_implementation.py +2 -10
  87. model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
  88. model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
  89. model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
  90. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
  91. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
  92. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
  93. model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
  94. model_compression_toolkit/core/keras/reader/common.py +1 -1
  95. model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
  96. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
  97. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
  98. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
  99. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
  100. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  101. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
  102. model_compression_toolkit/core/pytorch/constants.py +4 -6
  103. model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
  104. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  105. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  106. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
  107. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  108. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  109. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
  110. model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
  111. model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
  112. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
  113. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
  114. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
  115. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  116. model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
  117. model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
  118. model_compression_toolkit/core/runner.py +7 -7
  119. model_compression_toolkit/exporter/__init__.py +5 -0
  120. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  121. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  122. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  123. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
  124. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
  125. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  126. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
  127. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
  128. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
  129. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
  130. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  131. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
  132. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
  133. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
  134. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
  135. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
  136. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  137. model_compression_toolkit/gptq/__init__.py +6 -0
  138. model_compression_toolkit/gptq/common/gptq_config.py +57 -104
  139. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  140. model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
  141. model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
  142. model_compression_toolkit/gptq/common/gptq_training.py +30 -39
  143. model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
  144. model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
  145. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  146. model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
  147. model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
  148. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
  149. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
  150. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  151. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  152. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
  153. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
  154. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
  155. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
  156. model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
  157. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
  158. model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
  159. model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
  160. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
  161. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
  162. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
  163. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  164. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  165. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
  166. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
  167. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
  168. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
  169. model_compression_toolkit/gptq/runner.py +3 -2
  170. model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
  171. model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
  172. model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
  173. model_compression_toolkit/ptq/__init__.py +3 -0
  174. model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
  175. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
  176. model_compression_toolkit/qat/__init__.py +4 -0
  177. model_compression_toolkit/qat/common/__init__.py +1 -2
  178. model_compression_toolkit/qat/common/qat_config.py +3 -1
  179. model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
  180. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
  181. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
  182. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
  183. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
  184. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  185. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
  186. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
  187. model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
  188. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
  189. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  190. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
  191. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
  192. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
  193. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  194. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
  195. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
  196. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  197. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
  198. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
  199. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  200. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  201. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  202. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  203. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  204. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
  205. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
  206. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
  207. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  208. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
  209. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
  210. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
  211. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
  212. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
  213. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
  214. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
  215. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
  216. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  217. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
  218. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
  219. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
  220. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
  221. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
  222. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
  223. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
  224. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
  225. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
  226. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
  227. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
  228. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
  229. model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
  230. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
  231. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
  232. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
  233. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
  234. model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
  235. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
  236. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
  237. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
  238. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
  239. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
  240. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
  241. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
  242. model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
  243. model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
  244. model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
  245. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
  246. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
  247. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
  248. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
  249. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
  250. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
  251. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
  252. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
  253. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
  254. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
  255. model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
  256. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
  257. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
  258. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
  259. model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
  260. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
  261. model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
  262. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
  263. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
  264. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
  265. model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
  266. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
  267. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
  268. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
  269. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
  270. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
  271. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
  272. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
  273. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
  274. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
  275. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
  276. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
  277. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
  278. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
  279. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
  280. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
  281. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
  282. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
  283. model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
  284. model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
  285. model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
  286. model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
  287. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  288. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
  289. {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
  290. /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
  291. /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
  292. /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
  293. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
  294. /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
  295. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
  296. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
  297. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
  298. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
  299. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
  300. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
  301. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
  302. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
  303. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
  304. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
  305. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
  306. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
  307. /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -13,7 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH, LATEST
16
+ from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
17
+ from model_compression_toolkit.target_platform_capabilities.constants import LATEST
17
18
 
18
19
 
19
20
  ###############################
@@ -21,8 +22,8 @@ from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORC
21
22
  ###############################
22
23
  keras_tpc_models_dict = None
23
24
  if FOUND_TF:
24
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
25
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.latest import get_keras_tpc_latest
25
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
26
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_keras_tpc_latest
26
27
 
27
28
  # Keras: TPC versioning
28
29
  keras_tpc_models_dict = {'v1': get_keras_tpc_v1(),
@@ -33,9 +34,9 @@ if FOUND_TF:
33
34
  ###############################
34
35
  pytorch_tpc_models_dict = None
35
36
  if FOUND_TORCH:
36
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \
37
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \
37
38
  get_pytorch_tpc as get_pytorch_tpc_v1
38
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.latest import get_pytorch_tpc_latest
39
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_pytorch_tpc_latest
39
40
 
40
41
  # Pytorch: TPC versioning
41
42
  pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1(),
@@ -15,7 +15,10 @@
15
15
  from typing import List, Tuple
16
16
 
17
17
  import model_compression_toolkit as mct
18
- from model_compression_toolkit.core.common.target_platform import OpQuantizationConfig, TargetPlatformModel
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
19
+ TargetPlatformModel
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
21
+ QuantizationFormat
19
22
 
20
23
  tp = mct.target_platform
21
24
 
@@ -120,4 +123,7 @@ def generate_tp_model(default_config: OpQuantizationConfig,
120
123
  tp.Fusing([conv, relu])
121
124
  tp.Fusing([linear, relu])
122
125
 
126
+ # Set quantization format to fakely quant
127
+ generated_tpc.set_quantization_format(QuantizationFormat.FAKELY_QUANT)
128
+
123
129
  return generated_tpc
@@ -15,7 +15,7 @@
15
15
  import tensorflow as tf
16
16
 
17
17
  from packaging import version
18
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1 import __version__ as TPC_VERSION
18
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1 import __version__ as TPC_VERSION
19
19
 
20
20
  if version.parse(tf.__version__) < version.parse("2.6"):
21
21
  from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense, BatchNormalization, ReLU, \
@@ -23,7 +23,7 @@ if version.parse(tf.__version__) < version.parse("2.6"):
23
23
  else:
24
24
  from keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense, BatchNormalization, ReLU, Activation
25
25
 
26
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model
26
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model
27
27
  import model_compression_toolkit as mct
28
28
 
29
29
  tp = mct.target_platform
@@ -16,9 +16,9 @@ import torch
16
16
  from torch.nn import Conv2d, Linear, BatchNorm2d, ConvTranspose2d, Hardtanh, ReLU, ReLU6
17
17
  from torch.nn.functional import relu, relu6, hardtanh
18
18
 
19
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model
19
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model
20
20
  import model_compression_toolkit as mct
21
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1 import __version__ as TPC_VERSION
21
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1 import __version__ as TPC_VERSION
22
22
 
23
23
  tp = mct.target_platform
24
24
 
@@ -0,0 +1,22 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH
16
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
17
+ if FOUND_TF:
18
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
19
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import generate_keras_tpc
20
+ if FOUND_TORCH:
21
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_latest
22
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import generate_pytorch_tpc
@@ -13,7 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH, LATEST
16
+ from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
17
+ from model_compression_toolkit.target_platform_capabilities.constants import LATEST
17
18
 
18
19
 
19
20
  ###############################
@@ -21,8 +22,8 @@ from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORC
21
22
  ###############################
22
23
  keras_tpc_models_dict = None
23
24
  if FOUND_TF:
24
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
25
- from model_compression_toolkit.core.tpc_models.tflite_tpc.latest import get_keras_tpc_latest
25
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
26
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_keras_tpc_latest
26
27
 
27
28
  # Keras: TPC versioning
28
29
  keras_tpc_models_dict = {'v1': get_keras_tpc_v1(),
@@ -33,9 +34,9 @@ if FOUND_TF:
33
34
  ###############################
34
35
  pytorch_tpc_models_dict = None
35
36
  if FOUND_TORCH:
36
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tpc_pytorch import \
37
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import \
37
38
  get_pytorch_tpc as get_pytorch_tpc_v1
38
- from model_compression_toolkit.core.tpc_models.tflite_tpc.latest import get_pytorch_tpc_latest
39
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_pytorch_tpc_latest
39
40
 
40
41
  # Pytorch: TPC versioning
41
42
  pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1(),
@@ -15,8 +15,12 @@
15
15
  from typing import List, Tuple
16
16
 
17
17
  import model_compression_toolkit as mct
18
- from model_compression_toolkit.core.common.target_platform import OpQuantizationConfig, TargetPlatformModel
19
- from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
19
+ TargetPlatformModel
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
21
+ QuantizationMethod
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
23
+ QuantizationFormat
20
24
 
21
25
  tp = mct.target_platform
22
26
 
@@ -65,7 +69,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
65
69
  weights_multiplier_nbits=None
66
70
  )
67
71
 
68
- mixed_precision_cfg_list = [] # No mixed precision
72
+ mixed_precision_cfg_list = [] # No mixed precision
69
73
 
70
74
  return eight_bits, mixed_precision_cfg_list
71
75
 
@@ -106,28 +110,28 @@ def generate_tp_model(default_config: OpQuantizationConfig,
106
110
  # differently. For more details:
107
111
  # https://www.tensorflow.org/lite/performance/quantization_spec#int8_quantized_operator_specifications
108
112
  tp.OperatorsSet("NoQuantization",
109
- tp.get_default_quantization_config_options().clone_and_edit(
110
- quantization_preserving=True))
113
+ tp.get_default_quantization_config_options().clone_and_edit(
114
+ quantization_preserving=True))
111
115
 
112
116
  fc = tp.OperatorsSet("FullyConnected",
113
- tp.get_default_quantization_config_options().clone_and_edit(
114
- weights_per_channel_threshold=False))
117
+ tp.get_default_quantization_config_options().clone_and_edit(
118
+ weights_per_channel_threshold=False))
115
119
 
116
120
  tp.OperatorsSet("L2Normalization",
117
- tp.get_default_quantization_config_options().clone_and_edit(
118
- fixed_zero_point=0, fixed_scale=1 / 128))
121
+ tp.get_default_quantization_config_options().clone_and_edit(
122
+ fixed_zero_point=0, fixed_scale=1 / 128))
119
123
  tp.OperatorsSet("LogSoftmax",
120
- tp.get_default_quantization_config_options().clone_and_edit(
121
- fixed_zero_point=127, fixed_scale=16 / 256))
124
+ tp.get_default_quantization_config_options().clone_and_edit(
125
+ fixed_zero_point=127, fixed_scale=16 / 256))
122
126
  tp.OperatorsSet("Tanh",
123
- tp.get_default_quantization_config_options().clone_and_edit(
124
- fixed_zero_point=0, fixed_scale=1 / 128))
127
+ tp.get_default_quantization_config_options().clone_and_edit(
128
+ fixed_zero_point=0, fixed_scale=1 / 128))
125
129
  tp.OperatorsSet("Softmax",
126
- tp.get_default_quantization_config_options().clone_and_edit(
127
- fixed_zero_point=-128, fixed_scale=1 / 256))
130
+ tp.get_default_quantization_config_options().clone_and_edit(
131
+ fixed_zero_point=-128, fixed_scale=1 / 256))
128
132
  tp.OperatorsSet("Logistic",
129
- tp.get_default_quantization_config_options().clone_and_edit(
130
- fixed_zero_point=-128, fixed_scale=1 / 256))
133
+ tp.get_default_quantization_config_options().clone_and_edit(
134
+ fixed_zero_point=-128, fixed_scale=1 / 256))
131
135
 
132
136
  conv2d = tp.OperatorsSet("Conv2d")
133
137
  kernel = tp.OperatorSetConcat(conv2d, fc)
@@ -140,7 +144,8 @@ def generate_tp_model(default_config: OpQuantizationConfig,
140
144
  bias_add = tp.OperatorsSet("BiasAdd")
141
145
  add = tp.OperatorsSet("Add")
142
146
  squeeze = tp.OperatorsSet("Squeeze",
143
- qc_options=tp.get_default_quantization_config_options().clone_and_edit(quantization_preserving=True))
147
+ qc_options=tp.get_default_quantization_config_options().clone_and_edit(
148
+ quantization_preserving=True))
144
149
  # ------------------- #
145
150
  # Fusions
146
151
  # ------------------- #
@@ -152,4 +157,7 @@ def generate_tp_model(default_config: OpQuantizationConfig,
152
157
  tp.Fusing([batch_norm, activations_to_fuse])
153
158
  tp.Fusing([batch_norm, add, activations_to_fuse])
154
159
 
160
+ # Set quantization format to int8
161
+ generated_tpc.set_quantization_format(QuantizationFormat.INT8)
162
+
155
163
  return generated_tpc
@@ -24,11 +24,11 @@ else:
24
24
 
25
25
  from tensorflow.python.keras.layers.core import SlicingOpLambda
26
26
  from tensorflow.python.ops.image_ops_impl import ResizeMethod
27
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.attribute_filter import Eq
27
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import Eq
28
28
 
29
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tp_model import get_tp_model
29
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model
30
30
  import model_compression_toolkit as mct
31
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1 import __version__ as TPC_VERSION
31
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1 import __version__ as TPC_VERSION
32
32
 
33
33
  tp = mct.target_platform
34
34
 
@@ -15,11 +15,11 @@
15
15
  import torch
16
16
  from torch.nn import AvgPool2d, MaxPool2d
17
17
  from torch.nn.functional import avg_pool2d, max_pool2d, interpolate
18
- from model_compression_toolkit.core.common.target_platform.targetplatform2framework.attribute_filter import Eq
18
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import Eq
19
19
 
20
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tp_model import get_tp_model
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model
21
21
  import model_compression_toolkit as mct
22
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1 import __version__ as TPC_VERSION
22
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1 import __version__ as TPC_VERSION
23
23
 
24
24
  tp = mct.target_platform
25
25
 
@@ -1,25 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH
16
-
17
- from model_compression_toolkit.core.tpc_models.default_tpc.v4.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
18
-
19
- if FOUND_TF:
20
- from model_compression_toolkit.core.tpc_models.default_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_latest
21
- from model_compression_toolkit.core.tpc_models.default_tpc.v4.tpc_keras import generate_keras_tpc
22
-
23
- if FOUND_TORCH:
24
- from model_compression_toolkit.core.tpc_models.default_tpc.v4.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_latest
25
- from model_compression_toolkit.core.tpc_models.default_tpc.v4.tpc_pytorch import generate_pytorch_tpc
@@ -1,22 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH
16
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
17
- if FOUND_TF:
18
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
19
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tpc_keras import generate_keras_tpc
20
- if FOUND_TORCH:
21
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_latest
22
- from model_compression_toolkit.core.tpc_models.qnnpack_tpc.v1.tpc_pytorch import generate_pytorch_tpc
@@ -1,22 +0,0 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH
16
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
17
- if FOUND_TF:
18
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
19
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tpc_keras import generate_keras_tpc
20
- if FOUND_TORCH:
21
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_latest
22
- from model_compression_toolkit.core.tpc_models.tflite_tpc.v1.tpc_pytorch import generate_pytorch_tpc
@@ -1,93 +0,0 @@
1
- # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from typing import Any, List, Callable
16
-
17
- from model_compression_toolkit.core.common import Logger
18
- from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, REGULARIZATION_VALUES
19
-
20
-
21
- class GPTQQuantizerConfig:
22
- """
23
- A base class to define specific quantizer configuration for GPTQ quantizer.
24
- """
25
-
26
- def __init__(self):
27
- self.n_batches = None
28
-
29
- def get_regularization_value(self, fxp_model: Any, **kwargs) -> float:
30
- """
31
- Computes a regularization value for the quantizer's loss (if needed).
32
- In the base class it only returns 0, to be used for GPTQ quantizers that don't require regularization.
33
-
34
- Args:
35
- fxp_model: The quantized model that is being trained.
36
- **kwargs: Additional arguments for the quantizer regularization computation.
37
-
38
- Returns: The regularization value.
39
- """
40
-
41
- return 0
42
-
43
- def set_num_batches(self, num_batches: int):
44
- """
45
- Allows to set the number of batches that the quantizer uses for training (in each epoch).
46
-
47
- Args:
48
- num_batches: number of batches to be set.
49
-
50
- """
51
- self.n_batches = num_batches
52
-
53
-
54
- class SoftQuantizerConfig(GPTQQuantizerConfig):
55
- def __init__(self, entropy_regularization: float = REG_DEFAULT):
56
- """
57
- Initializes an object that holds the arguments that are needed for soft rounding quantizer.
58
-
59
- Args:
60
- entropy_regularization (float): A floating point number that defines the gumbel entropy regularization factor.
61
- """
62
-
63
- super().__init__()
64
- self.entropy_regularization = entropy_regularization
65
-
66
-
67
- def get_regularization_value(self, fxp_model: Any, **kwargs) -> float:
68
- """
69
- Computes a regularization value for the soft quantizer.
70
-
71
- Args:
72
- fxp_model: The quantized model that is being trained.
73
- **kwargs: Additional arguments for the quantizer regularization computation.
74
-
75
- Returns: The regularization value.
76
- """
77
-
78
- soft_rounding_reg_values = kwargs.get(REGULARIZATION_VALUES)
79
-
80
- if soft_rounding_reg_values is None:
81
- Logger.error("No regularization values has been provided for computing the regularization " # pragma: no cover
82
- "of the soft quantizer.")
83
- if not isinstance(soft_rounding_reg_values, List):
84
- Logger.error("The provided regularization values parameter of the soft quantizer " # pragma: no cover
85
- "is not compatible (should be a list).")
86
-
87
- reg = 0
88
-
89
- for sq in soft_rounding_reg_values:
90
- reg += sq
91
-
92
- return self.entropy_regularization * reg
93
-