mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -0,0 +1,101 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import tensorflow as tf
16
+ from packaging import version
17
+
18
+ if version.parse(tf.__version__) < version.parse("2.6"):
19
+ from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
20
+ MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
21
+ Conv2DTranspose
22
+ else:
23
+ from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
24
+ MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
25
+ Conv2DTranspose
26
+
27
+ from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tp_model import get_tp_model
28
+ import model_compression_toolkit as mct
29
+ from model_compression_toolkit.core.tpc_models.imx500_tpc.v1 import __version__ as TPC_VERSION
30
+
31
+ tp = mct.target_platform
32
+
33
+
34
+ def get_keras_tpc() -> tp.TargetPlatformCapabilities:
35
+ """
36
+ get a Keras TargetPlatformCapabilities object with default operation sets to layers mapping.
37
+ Returns: a Keras TargetPlatformCapabilities object for the given TargetPlatformModel.
38
+ """
39
+ imx500_tpc_tp_model = get_tp_model()
40
+ return generate_keras_tpc(name='imx500_tpc_keras_tpc', tp_model=imx500_tpc_tp_model)
41
+
42
+
43
+ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
44
+ """
45
+ Generates a TargetPlatformCapabilities object with default operation sets to layers mapping.
46
+
47
+ Args:
48
+ name: Name of the TargetPlatformCapabilities.
49
+ tp_model: TargetPlatformModel object.
50
+
51
+ Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel.
52
+ """
53
+
54
+ keras_tpc = tp.TargetPlatformCapabilities(tp_model, name=name, version=TPC_VERSION)
55
+
56
+ with keras_tpc:
57
+ tp.OperationsSetToLayers("NoQuantization", [Reshape,
58
+ tf.reshape,
59
+ Permute,
60
+ tf.transpose,
61
+ Flatten,
62
+ Cropping2D,
63
+ ZeroPadding2D,
64
+ Dropout,
65
+ MaxPooling2D,
66
+ tf.split,
67
+ tf.quantization.fake_quant_with_min_max_vars,
68
+ tf.math.argmax,
69
+ tf.shape,
70
+ tf.math.equal,
71
+ tf.gather,
72
+ tf.cast,
73
+ tf.compat.v1.gather,
74
+ tf.nn.top_k,
75
+ tf.__operators__.getitem,
76
+ tf.compat.v1.shape])
77
+
78
+ tp.OperationsSetToLayers("Conv", [Conv2D,
79
+ DepthwiseConv2D,
80
+ Conv2DTranspose,
81
+ tf.nn.conv2d,
82
+ tf.nn.depthwise_conv2d,
83
+ tf.nn.conv2d_transpose])
84
+ tp.OperationsSetToLayers("FullyConnected", [Dense])
85
+ tp.OperationsSetToLayers("AnyReLU", [tf.nn.relu,
86
+ tf.nn.relu6,
87
+ tf.nn.leaky_relu,
88
+ ReLU,
89
+ LeakyReLU,
90
+ tp.LayerFilterParams(Activation, activation="relu"),
91
+ tp.LayerFilterParams(Activation, activation="leaky_relu")])
92
+ tp.OperationsSetToLayers("Add", [tf.add, Add])
93
+ tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract])
94
+ tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply])
95
+ tp.OperationsSetToLayers("Div", [tf.math.divide])
96
+ tp.OperationsSetToLayers("PReLU", [PReLU])
97
+ tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")])
98
+ tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")])
99
+ tp.OperationsSetToLayers("Tanh", [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")])
100
+
101
+ return keras_tpc
@@ -0,0 +1,95 @@
1
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import operator
17
+
18
+ import torch
19
+ from torch import add, sub, mul, div, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, chunk, unbind, topk, \
20
+ gather, equal, transpose, permute
21
+ from torch.nn import Conv2d, Linear, BatchNorm2d, ConvTranspose2d
22
+ from torch.nn import Dropout, Flatten, Hardtanh
23
+ from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
24
+ from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu
25
+
26
+ from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tp_model import get_tp_model
27
+ import model_compression_toolkit as mct
28
+ from model_compression_toolkit.core.tpc_models.imx500_tpc.v1 import __version__ as TPC_VERSION
29
+
30
+ tp = mct.target_platform
31
+
32
+
33
+ def get_pytorch_tpc() -> tp.TargetPlatformCapabilities:
34
+ """
35
+ get a Pytorch TargetPlatformCapabilities object with default operation sets to layers mapping.
36
+ Returns: a Pytorch TargetPlatformCapabilities object for the given TargetPlatformModel.
37
+ """
38
+ imx500_tpc_tp_model = get_tp_model()
39
+ return generate_pytorch_tpc(name='imx500_tpc_pytorch_tpc', tp_model=imx500_tpc_tp_model)
40
+
41
+
42
+ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
43
+ """
44
+ Generates a TargetPlatformCapabilities object with default operation sets to layers mapping.
45
+ Args:
46
+ name: Name of the TargetPlatformModel.
47
+ tp_model: TargetPlatformModel object.
48
+ Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel.
49
+ """
50
+
51
+ pytorch_tpc = tp.TargetPlatformCapabilities(tp_model,
52
+ name=name,
53
+ version=TPC_VERSION)
54
+
55
+ with pytorch_tpc:
56
+ tp.OperationsSetToLayers("NoQuantization", [Dropout,
57
+ Flatten,
58
+ dropout,
59
+ flatten,
60
+ split,
61
+ operator.getitem,
62
+ reshape,
63
+ unsqueeze,
64
+ BatchNorm2d,
65
+ chunk,
66
+ unbind,
67
+ torch.Tensor.size,
68
+ permute,
69
+ transpose,
70
+ equal,
71
+ gather,
72
+ topk])
73
+
74
+ tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d])
75
+ tp.OperationsSetToLayers("FullyConnected", [Linear])
76
+ tp.OperationsSetToLayers("AnyReLU", [torch.relu,
77
+ ReLU,
78
+ ReLU6,
79
+ LeakyReLU,
80
+ relu,
81
+ relu6,
82
+ leaky_relu,
83
+ tp.LayerFilterParams(Hardtanh, min_val=0),
84
+ tp.LayerFilterParams(hardtanh, min_val=0)])
85
+
86
+ tp.OperationsSetToLayers("Add", [operator.add, add])
87
+ tp.OperationsSetToLayers("Sub", [operator.sub, sub])
88
+ tp.OperationsSetToLayers("Mul", [operator.mul, mul])
89
+ tp.OperationsSetToLayers("Div", [operator.truediv, div])
90
+ tp.OperationsSetToLayers("PReLU", [PReLU, prelu])
91
+ tp.OperationsSetToLayers("Swish", [SiLU, silu, Hardswish, hardswish])
92
+ tp.OperationsSetToLayers("Sigmoid", [Sigmoid, sigmoid])
93
+ tp.OperationsSetToLayers("Tanh", [Tanh, tanh])
94
+
95
+ return pytorch_tpc
@@ -12,3 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
16
+ from model_compression_toolkit.exporter.model_exporter.keras.keras_export_facade import keras_export_model, KerasExportMode
17
+ from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import PyTorchExportMode, pytorch_export_model
18
+ from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import tflite_export_model, TFLiteExportMode
19
+
@@ -13,15 +13,3 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH
17
-
18
- if FOUND_TF:
19
- from model_compression_toolkit.exporter.model_exporter.keras.keras_export_facade import \
20
- keras_export_model, KerasExportMode
21
- from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import tflite_export_model, \
22
- TFLiteExportMode
23
-
24
- if FOUND_TORCH:
25
- from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import PyTorchExportMode, \
26
- pytorch_export_model
27
-
@@ -54,4 +54,4 @@ class Exporter:
54
54
  Convert model and export it to a given path.
55
55
 
56
56
  """
57
- Logger.critical(f'Exporter {self.__class__} have to implement export method')
57
+ Logger.critical(f'Exporter {self.__class__} have to implement export method') # pragma: no cover
@@ -18,23 +18,12 @@ import keras.models
18
18
  import keras.models
19
19
  import tensorflow as tf
20
20
  from keras.engine.base_layer import Layer
21
- from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
22
- NoOpQuantizeConfig
23
21
 
24
22
  from model_compression_toolkit.core.common import Logger
25
23
  from model_compression_toolkit.exporter.model_exporter.keras.base_keras_exporter import \
26
24
  BaseKerasExporter
27
- from model_compression_toolkit.exporter.model_wrapper.keras.builder.quantize_config_to_node import \
28
- SUPPORTED_QUANTIZATION_CONFIG
29
- from model_compression_toolkit.exporter.model_wrapper.keras.extended_quantize_wrapper import ExtendedQuantizeWrapper
30
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.activation_quantize_config import \
31
- ActivationQuantizeConfig
32
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_activation_quantize_config \
33
- import \
34
- WeightsActivationQuantizeConfig
35
- from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_quantize_config import \
36
- WeightsQuantizeConfig
37
- from model_compression_toolkit.exporter.model_wrapper.keras.quantizers.fq_quantizer import FakeQuantQuantizer
25
+ from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
26
+
38
27
 
39
28
 
40
29
  class FakelyQuantKerasExporter(BaseKerasExporter):
@@ -83,7 +72,7 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
83
72
  assert self.is_layer_exportable_fn(layer), f'Layer {layer.name} is not exportable.'
84
73
 
85
74
  # If weights are quantized, use the quantized weight for the new built layer.
86
- if type(layer.quantize_config) in [WeightsQuantizeConfig, WeightsActivationQuantizeConfig]:
75
+ if layer.is_weights_quantization:
87
76
  new_layer = layer.layer.__class__.from_config(layer.layer.get_config())
88
77
  with tf.name_scope(new_layer.name):
89
78
  new_layer.build(layer.input_shape)
@@ -100,8 +89,10 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
100
89
  # that should be quantized. First, extract 'kernel' from variable name, check if the
101
90
  # quantize config contains this as an attribute for quantization. If so -
102
91
  # Take the quantized weight from the quantize_config and set it to the new layer.
103
- if w.name.split('/')[-1].split(':')[0] in layer.quantize_config.get_config()['weight_attrs']:
104
- val = layer.quantize_config.get_weights_and_quantizers(layer.layer)[0][1].weight
92
+ attribute_name = w.name.split('/')[-1].split(':')[0]
93
+ if attribute_name in layer.weights_quantizers.keys():
94
+ quantizer = layer.weights_quantizers.get(attribute_name)
95
+ val = quantizer(qw)
105
96
  else:
106
97
  val = qw
107
98
  if val is None:
@@ -113,23 +104,16 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
113
104
 
114
105
  # If activations are also quantized, wrap the layer back using ActivationQuantizeConfig
115
106
  # from original wrapper (weights wrapping is no longer needed).
116
- if isinstance(layer.quantize_config, WeightsActivationQuantizeConfig):
117
- new_layer = ExtendedQuantizeWrapper(new_layer, layer.quantize_config.act_config)
107
+ if layer.is_activation_quantization:
108
+ new_layer = KerasQuantizationWrapper(layer=new_layer,
109
+ activation_quantizers=layer.activation_quantizers)
118
110
 
119
111
  return new_layer
120
112
 
121
113
  # If this is a layer with activation quantization only, just return it
122
114
  # as activation quantization in the fake-quant case uses the wrapper for quantization.
123
- elif type(layer.quantize_config) in [ActivationQuantizeConfig]:
124
- return layer
115
+ return layer
125
116
 
126
- # Ideally we want in the case of no quantization to simply use the inner layer.
127
- # But for some reason when using SNC we are having issues to use the inner layer.
128
- # The clone_model method tries to reconstruct a model from the unwrapped configuration,
129
- # but when we have two TFOpLambda (like in the case of SNC: add and pad) one after another,
130
- # the output shape of the first one is in correct (it adds a new axis
131
- elif isinstance(layer.quantize_config, NoOpQuantizeConfig):
132
- return layer
133
117
 
134
118
  # clone each layer in the model and apply _unwrap_quantize_wrapper to layers wrapped with a QuantizeWrapper.
135
119
  self.exported_model = tf.keras.models.clone_model(self.model,
@@ -137,7 +121,7 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
137
121
  clone_function=_unwrap_quantize_wrapper)
138
122
 
139
123
  if self.exported_model is None:
140
- Logger.critical(f'Exporter can not save model as it is not exported')
124
+ Logger.critical(f'Exporter can not save model as it is not exported') # pragma: no cover
141
125
 
142
126
  Logger.info(f'Exporting FQ Keras model to: {self.save_model_path}')
143
127
 
@@ -145,14 +129,3 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
145
129
 
146
130
  return FakelyQuantKerasExporter.get_custom_objects()
147
131
 
148
- @staticmethod
149
- def get_custom_objects() -> Dict[str, type]:
150
- """
151
-
152
- Returns: A dictionary with objects for loading the exported model.
153
-
154
- """
155
- return {ExtendedQuantizeWrapper.__name__: ExtendedQuantizeWrapper,
156
- ActivationQuantizeConfig.__name__: ActivationQuantizeConfig,
157
- FakeQuantQuantizer.__name__: FakeQuantQuantizer,
158
- NoOpQuantizeConfig.__name__: NoOpQuantizeConfig}
@@ -15,44 +15,56 @@
15
15
  from enum import Enum
16
16
  from typing import Callable, Dict
17
17
 
18
- import keras
19
18
  from model_compression_toolkit.core.common import Logger
20
- from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import \
21
- FakelyQuantKerasExporter
19
+ from model_compression_toolkit.core.common.constants import FOUND_TF
22
20
 
23
21
 
24
22
  class KerasExportMode(Enum):
25
23
  FAKELY_QUANT = 0
26
24
 
27
25
 
28
- def keras_export_model(model: keras.models.Model,
29
- is_layer_exportable_fn: Callable,
30
- mode: KerasExportMode = KerasExportMode.FAKELY_QUANT,
31
- save_model_path: str = None) -> Dict[str, type]:
32
- """
33
- Prepare and return fully quantized model for export. Save exported model to
34
- a path if passed.
26
+ if FOUND_TF:
27
+ import keras
28
+ from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
29
+ from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
35
30
 
36
- Args:
37
- model: Model to export.
38
- is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
39
- mode: Mode to export the model according to.
40
- save_model_path: Path to save the model.
31
+ def keras_export_model(model: keras.models.Model,
32
+ save_model_path: str,
33
+ is_layer_exportable_fn: Callable = is_keras_layer_exportable,
34
+ mode: KerasExportMode = KerasExportMode.FAKELY_QUANT) -> Dict[str, type]:
35
+ """
36
+ Export a Keras quantized model to h5 model.
37
+ The model will be saved to the path in save_model_path.
38
+ Mode can be used for different exported files. Currently, keras_export_model
39
+ supports KerasExportMode.FAKELY_QUANT (where weights and activations are
40
+ float fakely-quantized values).
41
41
 
42
- Returns:
43
- Custom objects dictionary needed to load the model.
42
+ Args:
43
+ model: Model to export.
44
+ is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
45
+ mode: Mode to export the model according to.
46
+ save_model_path: Path to save the model.
44
47
 
45
- """
48
+ Returns:
49
+ Custom objects dictionary needed to load the model.
46
50
 
47
- if mode == KerasExportMode.FAKELY_QUANT:
48
- exporter = FakelyQuantKerasExporter(model,
49
- is_layer_exportable_fn,
50
- save_model_path)
51
+ """
51
52
 
52
- else:
53
- Logger.critical(
54
- f'Unsupported mode was used {mode.name} to export Keras model. Please see API for supported modes.')
53
+ if mode == KerasExportMode.FAKELY_QUANT:
54
+ exporter = FakelyQuantKerasExporter(model,
55
+ is_layer_exportable_fn,
56
+ save_model_path)
55
57
 
56
- exporter.export()
58
+ else:
59
+ Logger.critical(
60
+ f'Unsupported mode was used {mode.name} to '
61
+ f'export Keras model. Please see API for supported modes.') # pragma: no cover
57
62
 
58
- return exporter.get_custom_objects()
63
+ exporter.export()
64
+
65
+ return exporter.get_custom_objects()
66
+ else:
67
+ def keras_export_model(*args, **kwargs):
68
+ Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
69
+ 'when using keras_export_model. '
70
+ 'Could not find some or all of TensorFlow packages.') # pragma: no cover
@@ -19,7 +19,13 @@ import torch.nn
19
19
  from model_compression_toolkit.core.common import Logger
20
20
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
21
21
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
22
+ from packaging import version
22
23
 
24
+ # ONNX opset version 16 is supported from PyTorch 1.12
25
+ if version.parse(torch.__version__) < version.parse("1.12"):
26
+ OPSET_VERSION = 15
27
+ else:
28
+ OPSET_VERSION = 16
23
29
 
24
30
  class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
25
31
  """
@@ -57,7 +63,9 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
57
63
  Returns:
58
64
  Fake-quant PyTorch model.
59
65
  """
60
- # assert self.is_layer_exportable_fn(layer), f'Layer {layer.name} is not exportable.'
66
+ for layer in self.model.children():
67
+ assert self.is_layer_exportable_fn(layer), f'Layer {layer.name} is not exportable.'
68
+
61
69
  model_input = to_torch_tensor(next(self.repr_dataset())[0])
62
70
 
63
71
  Logger.info(f"Exporting PyTorch fake quant onnx model: {self.save_model_path}")
@@ -65,7 +73,7 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
65
73
  torch.onnx.export(self.model,
66
74
  model_input,
67
75
  self.save_model_path,
68
- opset_version=13,
76
+ opset_version=OPSET_VERSION,
69
77
  verbose=False,
70
78
  input_names=['input'],
71
79
  output_names=['output'],
@@ -56,8 +56,12 @@ class FakelyQuantTorchScriptPyTorchExporter(BasePyTorchExporter):
56
56
  Returns:
57
57
  Fake-quant PyTorch model.
58
58
  """
59
- # assert self.is_layer_exportable_fn(layer), f'Layer {layer.name} is not exportable.'
60
- torch_traced = torch.jit.trace(self.model, to_torch_tensor(next(self.repr_dataset())))
59
+ for layer in self.model.children():
60
+ assert self.is_layer_exportable_fn(layer), f'Layer {layer} is not exportable.'
61
+
62
+ torch_traced = torch.jit.trace(self.model,
63
+ to_torch_tensor(next(self.repr_dataset())),
64
+ check_trace=True)
61
65
  self.exported_model = torch.jit.script(torch_traced)
62
66
  Logger.info(f"Exporting PyTorch torch script Model: {self.save_model_path}")
63
67
  torch.jit.save(self.exported_model, self.save_model_path)
@@ -15,13 +15,8 @@
15
15
  from enum import Enum
16
16
  from typing import Callable
17
17
 
18
- import torch.nn
19
-
20
18
  from model_compression_toolkit.core.common import Logger
21
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
22
- FakelyQuantONNXPyTorchExporter
23
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import \
24
- FakelyQuantTorchScriptPyTorchExporter
19
+ from model_compression_toolkit.core.common.constants import FOUND_TORCH
25
20
 
26
21
 
27
22
  class PyTorchExportMode(Enum):
@@ -29,38 +24,56 @@ class PyTorchExportMode(Enum):
29
24
  FAKELY_QUANT_ONNX = 1
30
25
 
31
26
 
32
- def pytorch_export_model(model: torch.nn.Module,
33
- is_layer_exportable_fn: Callable,
34
- mode: PyTorchExportMode = PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT,
35
- save_model_path: str = None,
36
- repr_dataset: Callable = None) -> None:
37
- """
38
- Prepare and return fully quantized model for export. Save exported model to
39
- a path if passed.
27
+ if FOUND_TORCH:
28
+ import torch.nn
29
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
30
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
31
+ from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
32
+
33
+ def pytorch_export_model(model: torch.nn.Module,
34
+ save_model_path: str,
35
+ repr_dataset: Callable,
36
+ is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
37
+ mode: PyTorchExportMode = PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT) -> None:
38
+ """
39
+ Export a PyTorch quantized model to a torchscript or onnx model.
40
+ The model will be saved to the path in save_model_path.
41
+ Mode can be used for different exported files. Currently, pytorch_export_model
42
+ supports PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT (where the exported model
43
+ is in a TorchScript format and its weights and activations are float fakely-quantized values),
44
+ and PyTorchExportMode.FakelyQuantONNX (where the exported model
45
+ is in an ONNX format and its weights and activations are float fakely-quantized values)
46
+
47
+ Args:
48
+ model: Model to export.
49
+ is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
50
+ mode: Mode to export the model according to.
51
+ save_model_path: Path to save the model.
52
+ repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
40
53
 
41
- Args:
42
- model: Model to export.
43
- is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
44
- mode: Mode to export the model according to.
45
- save_model_path: Path to save the model.
46
- repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
54
+ """
47
55
 
48
- """
56
+ if mode == PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT:
57
+ exporter = FakelyQuantTorchScriptPyTorchExporter(model,
58
+ is_layer_exportable_fn,
59
+ save_model_path,
60
+ repr_dataset)
49
61
 
50
- if mode == PyTorchExportMode.FAKELY_QUANT_TORCHSCRIPT:
51
- exporter = FakelyQuantTorchScriptPyTorchExporter(model,
52
- is_layer_exportable_fn,
53
- save_model_path,
54
- repr_dataset)
62
+ elif mode == PyTorchExportMode.FAKELY_QUANT_ONNX:
63
+ exporter = FakelyQuantONNXPyTorchExporter(model,
64
+ is_layer_exportable_fn,
65
+ save_model_path,
66
+ repr_dataset)
55
67
 
56
- elif mode == PyTorchExportMode.FAKELY_QUANT_ONNX:
57
- exporter = FakelyQuantONNXPyTorchExporter(model,
58
- is_layer_exportable_fn,
59
- save_model_path,
60
- repr_dataset)
68
+ else:
69
+ Logger.critical(
70
+ f'Unsupported mode was used {mode.name} to export PyTorch model. '
71
+ f'Please see API for supported modes.') # pragma: no cover
61
72
 
62
- else:
63
- Logger.critical(
64
- f'Unsupported mode was used {mode.name} to export PyTorch model. Please see API for supported modes.')
73
+ exporter.export()
65
74
 
66
- exporter.export()
75
+ else:
76
+ def pytorch_export_model(*args, **kwargs):
77
+ Logger.error('Installing torch is mandatory '
78
+ 'when using pytorch_export_model. '
79
+ 'Could not find PyTorch packages.') # pragma: no cover
@@ -18,8 +18,8 @@ from typing import Callable
18
18
 
19
19
  import keras.models
20
20
  import tensorflow as tf
21
- from keras.models import load_model
22
21
 
22
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
23
23
  from model_compression_toolkit.core.common import Logger
24
24
  from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
25
25
 
@@ -61,7 +61,8 @@ class FakelyQuantTFLiteExporter(FakelyQuantKerasExporter):
61
61
  custom_objects = FakelyQuantKerasExporter(self.model,
62
62
  self.is_layer_exportable_fn,
63
63
  tmp_h5_file).export()
64
- model = load_model(tmp_h5_file, custom_objects)
64
+
65
+ model = keras_load_quantized_model(tmp_h5_file)
65
66
  os.remove(tmp_h5_file)
66
67
 
67
68
  self.exported_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()