mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0

model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 from abc import abstractmethod
-from typing import Tuple, Any, Dict, List, Union
+from typing import Tuple, Any, Dict, List, Union, Callable

 import torch
 from networkx import topological_sort
@@ -25,7 +25,7 @@ from model_compression_toolkit.core.common.back2framework.base_model_builder imp
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
+from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder, identity_wrapper
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, BufferHolder
 from model_compression_toolkit.core.pytorch.utils import get_working_device
@@ -65,7 +65,8 @@ def _build_input_tensors_list(node: BaseNode,
 def _run_operation(n: BaseNode,
                    input_tensors: List,
                    op_func: Any,
-                   quantize_node_activation_fn) -> Tuple[Union[List,torch.Tensor], Union[List,torch.Tensor]]:
+                   quantize_node_activation_fn,
+                   is_wrapped: bool) -> Tuple[Union[List,torch.Tensor], Union[List,torch.Tensor]]:
     """
     Applying the layer (op_func) to the input tensors (input_tensors).
     If quantized is set to True, and the layer's corresponding node (n) has quantization
@@ -76,6 +77,7 @@ def _run_operation(n: BaseNode,
         input_tensors: List of Pytorch tensors that are the layer's inputs.
         op_func: Module/functional to apply to the input tensors.
         quantize_node_activation_fn: quantization function
+        is_wrapped: Flag to indicate if layer is already quantization wrapped so no activation is needed
     Returns:
         A tuple of Pytorch tensors. The Module/functional output tensors after applying the
         Module/functional to the input tensors.
@@ -90,7 +92,7 @@ def _run_operation(n: BaseNode,

     # Add a fake quant node if the node has an activation threshold.
     out_tensors_of_n = out_tensors_of_n_float
-    if n.is_activation_quantization_enabled():
+    if n.is_activation_quantization_enabled() and not is_wrapped:
         if isinstance(out_tensors_of_n_float, list):
             out_tensors_of_n_float = torch.cat(out_tensors_of_n_float, dim=0)
         out_tensors_of_n = quantize_node_activation_fn(n, out_tensors_of_n_float)
@@ -142,7 +144,8 @@ class PytorchModel(torch.nn.Module):
                  graph: Graph,
                  append2output: List[Any] = None,
                  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
-                 return_float_outputs: bool = False):
+                 return_float_outputs: bool = False,
+                 wrapper: Callable = identity_wrapper):
         """
         Construct a Pytorch model.

@@ -151,6 +154,7 @@ class PytorchModel(torch.nn.Module):
             append2output: List of nodes or OutTensor objects.
             fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
             return_float_outputs: Whether the model returns float tensors or not.
+            wrapper: A function wrapper Pytorch Layers.
         """
         super(PytorchModel, self).__init__()
         self.graph = graph
@@ -159,6 +163,7 @@ class PytorchModel(torch.nn.Module):
         self.append2output = append2output
         self.return_float_outputs = return_float_outputs
         self.fw_info = fw_info
+        self.wrapper = wrapper
         self._add_modules()

     @abstractmethod
@@ -176,17 +181,21 @@ class PytorchModel(torch.nn.Module):
             Output of the node.

         """
-        raise NotImplemented(f'{self.__class__.__name__} have to implement a method for quantization activation nodes.')
+        raise NotImplemented(f'{self.__class__.__name__} '
+                             f'have to implement a method for quantization activation nodes.')  # pragma: no cover

     def _add_modules(self):
         for n in self.node_sort:
-            if not isinstance(n, FunctionalNode):
+            if isinstance(n, FunctionalNode):
+                # for functional layers
+                setattr(self, n.name, self.wrapper(n, n.type))
+            else:
                 if n.type == BufferHolder:
                     self.add_module(n.name, node_builder(n))
                     self.get_submodule(n.name). \
                         register_buffer(n.name, torch.Tensor(n.get_weights_by_keys(BUFFER)).to(get_working_device()))
                 else:
-                    self.add_module(n.name, node_builder(n))
+                    self.add_module(n.name, self.wrapper(n, node_builder(n)))

     def forward(self,
                 *args: Any) -> Any:
@@ -211,7 +220,8 @@
                 out_tensors_of_n, out_tensors_of_n_float = _run_operation(n,
                                                                           input_tensors,
                                                                           op_func=op_func,
-                                                                          quantize_node_activation_fn=self._quantize_node_activations)
+                                                                          quantize_node_activation_fn=self._quantize_node_activations,
+                                                                          is_wrapped=self.wrapper is not identity_wrapper)

                 if isinstance(out_tensors_of_n, list):
                     node_to_output_tensors_dict.update({n: out_tensors_of_n})
@@ -244,7 +254,7 @@
         Returns: Module/functional to apply to the input tensors.

         """
-        return node.type if isinstance(node, FunctionalNode) else getattr(self, node.name)
+        return getattr(self, node.name)


 class PyTorchModelBuilder(BaseModelBuilder):
@@ -256,7 +266,8 @@ class PyTorchModelBuilder(BaseModelBuilder):
                  graph: common.Graph,
                  append2output=None,
                  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
-                 return_float_outputs: bool = False):
+                 return_float_outputs: bool = False,
+                 wrapper: Callable = identity_wrapper):
         """

         Args:
@@ -264,6 +275,7 @@
             append2output: Nodes to append to model's output.
             fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
+            wrapper: A function wrapper Pytorch Layers.
         """

         super().__init__(graph,
@@ -271,6 +283,8 @@
                          fw_info,
                          return_float_outputs)

+        self.wrapper = wrapper
+
     def build_model(self) -> Tuple[PytorchModel, UserInformation]:
         """
         Build a PyTorch model and return it.
@@ -279,4 +293,5 @@
         """
         return PytorchModel(self.graph,
                             self.append2output,
-                            return_float_outputs=self.return_float_outputs), self.graph.user_info
+                            return_float_outputs=self.return_float_outputs,
+                            wrapper=self.wrapper), self.graph.user_info
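
Note on the new wrapper hook above: every built layer now passes through self.wrapper, and the activation fake-quant step in _run_operation is skipped whenever a non-default wrapper is installed, since such a wrapper is expected to quantize its own outputs. A minimal sketch of the mechanism (identity_wrapper's body is not shown in this diff and is assumed to return the module unchanged; FakeQuantWrapper is a hypothetical illustration, not part of the package):

    import torch

    def identity_wrapper(node, module):
        # Assumed default behavior: keep the built layer as-is.
        return module

    class FakeQuantWrapper(torch.nn.Module):
        # Hypothetical wrapper that quantizes its layer's outputs itself, which is
        # why the model builder passes is_wrapped=True and skips its own fake-quant.
        def __init__(self, module, n_bits=8, threshold=1.0):
            super().__init__()
            self.module = module
            self.scale = threshold / 2 ** (n_bits - 1)
            self.qmin, self.qmax = -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1

        def forward(self, x):
            y = self.module(x)
            return torch.fake_quantize_per_tensor_affine(y, self.scale, 0, self.qmin, self.qmax)

    # builder = PyTorchModelBuilder(graph, wrapper=lambda n, m: FakeQuantWrapper(m))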

model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py
@@ -38,8 +38,7 @@ class WrapperQuantizeConfig:
         Returns: A List of quantizers for weights quantization.

         """
-        raise NotImplemented
-
+        raise NotImplemented  # pragma: no cover

     def get_activation_quantizers(self) -> list:
         """
@@ -47,7 +46,7 @@ class WrapperQuantizeConfig:
         Returns: A List of quantizers for activation quantization.

         """
-        raise NotImplemented
+        raise NotImplemented  # pragma: no cover


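For context (standard Python behavior, independent of this diff): NotImplemented is a constant, not an exception class, so `raise NotImplemented` actually surfaces as a TypeError at the call site; NotImplementedError is the conventional exception for abstract methods. A standalone illustration:

    class Demo:
        def get_weights_quantizers(self) -> list:
            raise NotImplemented  # actually raises TypeError, not this constant

    try:
        Demo().get_weights_quantizers()
    except TypeError as e:
        print(e)  # exceptions must derive from BaseException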

model_compression_toolkit/core/pytorch/constants.py
@@ -71,6 +71,7 @@ RELU_POT_BOUND = 8.0

 # Supported TP models names for Pytorch:
 DEFAULT_TP_MODEL = 'default'
+IMX500_TP_MODEL = 'imx500'
 TFLITE_TP_MODEL = 'tflite'
 QNNPACK_TP_MODEL = 'qnnpack'

@@ -91,3 +92,7 @@ IN_PROJ_WEIGHT = 'in_proj_weight'
 IN_PROJ_BIAS = 'in_proj_bias'
 BIAS_K = 'bias_k'
 BIAS_V = 'bias_v'
+
+# Batch size value for 'reshape' and 'view' operators,
+# the value is -1 so the batch size is inferred from the length of the array and remaining dimensions.
+BATCH_DIM_VALUE = -1
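
The semantics of BATCH_DIM_VALUE follow PyTorch's standard shape inference, where a single -1 dimension is derived from the tensor's element count:

    import torch

    x = torch.randn(32, 3, 8, 8)
    # With the batch dim set to BATCH_DIM_VALUE (-1), the same static shape
    # works for any batch size:
    y = torch.reshape(x, (-1, 3, 64))
    assert y.shape == (32, 3, 64)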

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py
@@ -20,6 +20,7 @@ import torch.nn as nn
 import operator
 from typing import List

+from model_compression_toolkit.core.common.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -46,32 +47,26 @@ class MHAParams:
         # Only batch first network is supported
         if BATCH_FIRST in mha_node.framework_attr.keys():
             if mha_node.framework_attr[BATCH_FIRST] is not True:
-                raise Exception('Only batch first network is supported')
+                Logger.error('Only batch first network is supported')  # pragma: no cover
         else:
-            raise Exception('Only batch first network is supported')
+            Logger.error('Only batch first network is supported')  # pragma: no cover

         # Add Zero Attn feature is Not Implemented
         if ADD_ZERO_ATTN in mha_node.framework_attr.keys():
             if mha_node.framework_attr[ADD_ZERO_ATTN] is not False:
-                raise Exception('Add Zero Attn feature is Not Implemented')
+                Logger.error('Add Zero Attn feature is Not Implemented')  # pragma: no cover

         # Check if Add Bias KV feature is Active
         if BIAS_K and BIAS_V in mha_node.weights.keys():
-            if mha_node.weights[BIAS_K] and mha_node.weights[BIAS_V] is not None:
-                raise Exception('Add BIAS_KV feature is Not Implemented')
+            if mha_node.weights[BIAS_K] is not None and mha_node.weights[BIAS_V] is not None:
+                Logger.error('Add BIAS_KV feature is Not Implemented')  # pragma: no cover

         self.embed_dim = mha_node.framework_attr[EMBED_DIM]
         self.num_heads = mha_node.framework_attr[NUM_HEADS]

-        if KEY_DIM in mha_node.framework_attr:
-            self.kdim = mha_node.framework_attr[KEY_DIM]
-        else:
-            self.kdim = False
+        self.kdim = mha_node.framework_attr[KEY_DIM]

-        if VALUE_DIM in mha_node.framework_attr:
-            self.vdim = mha_node.framework_attr[VALUE_DIM]
-        else:
-            self.vdim = False
+        self.vdim = mha_node.framework_attr[VALUE_DIM]

         self.qdim = int(self.embed_dim / self.num_heads)

@@ -707,7 +702,7 @@ class MultiHeadAttentionDecomposition(common.BaseSubstitution):
         """

         if mha_node.reuse:
-            raise Exception("MCT doesn't support reuse of MultiHeadAttention layer")
+            raise Exception("MCT doesn't support reuse of MultiHeadAttention layer")  # pragma: no cover
         params = MHAParams(mha_node)

         # project
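
The BIAS_KV condition fix above is worth spelling out: `weights[BIAS_K] and weights[BIAS_V] is not None` parses as `weights[BIAS_K] and (weights[BIAS_V] is not None)`, so the left operand is truth-tested directly, and truth-testing a multi-element numpy array raises a ValueError. An illustrative snippet (assuming the stored weights are numpy arrays):

    import numpy as np

    bias_k, bias_v = np.ones(4), np.ones(4)
    try:
        _ = bias_k and bias_v is not None   # old form: ambiguous array truth test
    except ValueError as e:
        print(e)  # truth value of an array with more than one element is ambiguous
    print(bias_k is not None and bias_v is not None)  # new form: True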

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py
@@ -14,10 +14,13 @@
 # ==============================================================================
 from torch import reshape
 import torch
+
+from model_compression_toolkit.core.common import Logger
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
+from model_compression_toolkit.core.pytorch.constants import BATCH_DIM_VALUE


 class ReshapeWithStaticShapes(common.BaseSubstitution):
@@ -47,14 +50,25 @@ class ReshapeWithStaticShapes(common.BaseSubstitution):
         Returns:
             Graph after applying the substitution.
         """
+        # we want the batch size value to infer from the length of the array and remaining dimensions
+        if len(node.output_shape) == 1:
+            node.output_shape[0][0] = BATCH_DIM_VALUE
+        else:
+            Logger.error('Reshape or view nodes should have a single output shape')  # pragma: no cover
+
         # configure the new static output shape attribute
         node.op_call_args = node.output_shape

         # modify the node input info
         node.input_shape = [node.input_shape[0]]
+
+        # the first input is the tensor to be reshaped, we want his batch size value to infer
+        # from the length of the array and remaining dimensions
+        node.input_shape[0][0] = BATCH_DIM_VALUE
+
         nodes_to_check = []
         for in_edge in graph.incoming_edges(node):
-            if in_edge.sink_index > 0: # the first input is the tensor to be reshaped
+            if in_edge.sink_index > 0:  # the first input is the tensor to be reshaped
                 nodes_to_check.append(in_edge.source_node)
                 graph.remove_edge(in_edge.source_node, node)
         for n in nodes_to_check:
@@ -80,4 +94,4 @@ def clean_graph_from_nodes_without_out_edges(graph: Graph,
             graph.remove_edge(in_edge.source_node, node)
         graph.remove_node(node)
     for n in nodes_to_check:
-        clean_graph_from_nodes_without_out_edges(graph, n)
+        clean_graph_from_nodes_without_out_edges(graph, n)
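
Net effect of the substitution: a data-dependent call such as x.view(x.size(0), -1) becomes a static reshape whose target shape is baked into op_call_args, with the batch dimension left as -1 so any batch size still works. With hypothetical shapes for illustration:

    import torch

    output_shape = [[8, 192]]        # shape recorded while tracing with batch size 8
    output_shape[0][0] = -1          # BATCH_DIM_VALUE
    op_call_args = output_shape      # the node now calls torch.reshape(x, [-1, 192])

    x = torch.randn(16, 3, 64)       # a different batch size at inference time
    assert torch.reshape(x, op_call_args[0]).shape == (16, 192)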

model_compression_toolkit/core/pytorch/kpi_data_facade.py
@@ -154,9 +154,9 @@ else:
     # we raise an exception when trying to use this function.
     def pytorch_kpi_data(*args, **kwargs):
         Logger.critical('Installing torch is mandatory when using pytorch_kpi_data. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover


     def pytorch_kpi_data_experimental(*args, **kwargs):
         Logger.critical('Installing torch is mandatory when using pytorch_kpi_data. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover

model_compression_toolkit/core/pytorch/quantization_facade.py
@@ -269,9 +269,9 @@ else:
     def pytorch_post_training_quantization(*args, **kwargs):
         Logger.critical('Installing Pytorch is mandatory '
                         'when using pytorch_post_training_quantization. '
-                        'Could not find the torch package.')
+                        'Could not find the torch package.')  # pragma: no cover

     def pytorch_post_training_quantization_mixed_precision(*args, **kwargs):
         Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                         'when using pytorch_post_training_quantization_mixed_precision. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover

model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py
@@ -17,6 +17,7 @@ import torch

 from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
 from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import threshold_is_power_of_two
+from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero


 def get_symmetric_quantization_range_and_scale(activation_is_signed: bool,
@@ -62,9 +63,9 @@ def power_of_two_quantization(activation_n_bits: int,
     activation_is_signed = quantization_params.get(SIGNED)

     if activation_threshold is None or activation_is_signed is None:
-        return None
+        return None  # pragma: no cover
     if not threshold_is_power_of_two(activation_threshold, per_channel=False):
-        return None
+        return None  # pragma: no cover

     min_value, max_value, scale = get_symmetric_quantization_range_and_scale(activation_is_signed,
                                                                              activation_n_bits,
@@ -90,7 +91,7 @@ def symmetric_quantization(activation_n_bits: int,
     activation_is_signed = quantization_params.get(SIGNED)

     if activation_threshold is None or activation_is_signed is None:
-        return None
+        return None  # pragma: no cover

     min_value, max_value, scale = get_symmetric_quantization_range_and_scale(activation_is_signed,
                                                                              activation_n_bits,
@@ -115,16 +116,17 @@ def uniform_quantization(activation_n_bits: int,
     a, b = quantization_params.get(RANGE_MIN), quantization_params.get(RANGE_MAX)

     if a is None or b is None:
-        return None
+        return None  # pragma: no cover

     # fixing quantization range to include 0
     a = 0 if a > 0 else a
     b = 0 if b < 0 else b
+    a, b = fix_range_to_include_zero(a, b, activation_n_bits)

     min_value = 0
     max_value = 2 ** activation_n_bits - 1
     scale = (b - a) / ((2 ** activation_n_bits) - 1)
-    zero_point = -int(a / scale)  # zp has to be positive, and a <=0, so we multiply by -1
+    zero_point = -round(a / scale)  # zp has to be positive, and a <=0, so we multiply by -1

     return lambda x: q(x, min_value, max_value, scale, zero_point)

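The zero-point change above deserves a worked example: int() truncates toward zero, so the old zero-point could sit almost a whole quantization step away from the nearest grid point, while round() is off by at most half a step (fix_range_to_include_zero, per its name, additionally adjusts the range so zero itself is exactly representable). With hypothetical range values:

    a, b = -0.08, 1.4
    n_bits = 8
    scale = (b - a) / (2 ** n_bits - 1)  # ~0.0058
    print(a / scale)                     # ~-13.78
    print(-int(a / scale))               # 13 (old: truncation)
    print(-round(a / scale))             # 14 (new: nearest grid point)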

model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py
@@ -57,7 +57,7 @@ class PytorchLUTFakeQuant(torch.nn.Module):
             Quantized torch Tensor.
         """
         if self.activation_is_signed is None or self.cluster_centers is None or self.threshold is None:
-            return None
+            return None  # pragma: no cover

         _quant_output = self.lut_kmeans_quantizer(x)
         return _quant_output

model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py
@@ -17,14 +17,18 @@ from model_compression_toolkit.core.common.target_platform import TargetPlatform

 from model_compression_toolkit.core.tpc_models.default_tpc.target_platform_capabilities import \
     tpc_dict as default_tpc_dict
+from model_compression_toolkit.core.tpc_models.imx500_tpc.target_platform_capabilities import \
+    tpc_dict as imx500_tpc_dict
 from model_compression_toolkit.core.tpc_models.tflite_tpc.target_platform_capabilities import \
     tpc_dict as tflite_tpc_dict
 from model_compression_toolkit.core.tpc_models.qnnpack_tpc.target_platform_capabilities import \
     tpc_dict as qnnpack_tpc_dict
-from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL
+from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, \
+    QNNPACK_TP_MODEL
 from model_compression_toolkit.core.common.constants import LATEST

 tpc_dict = {DEFAULT_TP_MODEL: default_tpc_dict,
+            IMX500_TP_MODEL: imx500_tpc_dict,
             TFLITE_TP_MODEL: tflite_tpc_dict,
             QNNPACK_TP_MODEL: qnnpack_tpc_dict}

@@ -35,7 +39,7 @@ def get_target_platform_capabilities(fw_name: str,
     """
     Get a TargetPlatformCapabilities by the target platform model name and the framework name.
     For now, it supports frameworks 'tensorflow' and 'pytorch'. For both of them
-    the target platform model can be 'default','tflite', or 'qnnpack'.
+    the target platform model can be 'default', 'imx500', 'tflite', or 'qnnpack'.

     Args:
         fw_name: Framework name of the TargetPlatformCapabilities.
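
With the imx500 entry registered, it becomes selectable through the same facade as the existing TP models; for example (assuming the facade keeps its signature of framework name, TP model name, and an optional version, and is re-exported at the package root as in prior releases):

    import model_compression_toolkit as mct

    keras_tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500')
    pytorch_tpc = mct.get_target_platform_capabilities('pytorch', 'imx500', target_platform_version='v1')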

model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH
+from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
+    get_op_quantization_configs
+if FOUND_TF:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_keras import generate_keras_tpc
+if FOUND_TORCH:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_pytorch import get_pytorch_tpc as \
+        get_pytorch_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_pytorch import generate_pytorch_tpc

model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py
@@ -0,0 +1,45 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH, LATEST
+
+###############################
+# Build Tensorflow TPC models
+###############################
+keras_tpc_models_dict = None
+if FOUND_TF:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.latest import get_keras_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
+
+    # Keras: TPC versioning
+    keras_tpc_models_dict = {'v1': get_keras_tpc_v1(),
+                             LATEST: get_keras_tpc_latest()}
+
+###############################
+# Build Pytorch TPC models
+###############################
+pytorch_tpc_models_dict = None
+if FOUND_TORCH:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.latest import get_pytorch_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_pytorch import \
+        get_pytorch_tpc as get_pytorch_tpc_v1
+
+    # Pytorch: TPC versioning
+    pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1(),
+                               LATEST: get_pytorch_tpc_latest()}
+
+tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
+            PYTORCH: pytorch_tpc_models_dict}
+

model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+__version__ = 'v1'

model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py
@@ -0,0 +1,156 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List, Tuple
+
+import model_compression_toolkit as mct
+from model_compression_toolkit.core.common.target_platform import OpQuantizationConfig, TargetPlatformModel
+
+tp = mct.target_platform
+
+
+def get_tp_model() -> TargetPlatformModel:
+    """
+    A method that generates a default target platform model, with base 8-bit quantization configuration and 8, 4, 2
+    bits configuration list for mixed-precision quantization.
+    NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets
+    (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the
+    'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations.
+
+    Returns: A TargetPlatformModel object.
+
+    """
+    base_config, mixed_precision_cfg_list = get_op_quantization_configs()
+    return generate_tp_model(default_config=base_config,
+                             base_config=base_config,
+                             mixed_precision_cfg_list=mixed_precision_cfg_list,
+                             name='imx500_tp_model')
+
+
+def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+    """
+    Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel.
+    In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as
+    default configuration for mixed-precision quantization.
+
+    Returns: An OpQuantizationConfig config object and a list of OpQuantizationConfig objects.
+
+    """
+    # Create a quantization config.
+    # A quantization configuration defines how an operator
+    # should be quantized on the modeled hardware:
+    eight_bits = tp.OpQuantizationConfig(
+        activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
+        weights_quantization_method=tp.QuantizationMethod.SYMMETRIC,
+        activation_n_bits=8,
+        weights_n_bits=8,
+        weights_per_channel_threshold=True,
+        enable_weights_quantization=True,
+        enable_activation_quantization=True,
+        quantization_preserving=False,
+        fixed_scale=None,
+        fixed_zero_point=None,
+        weights_multiplier_nbits=None)
+
+    # To quantize a model using mixed-precision, create
+    # a list with more than one OpQuantizationConfig.
+    # In this example, we quantize some operations' weights
+    # using 2, 4 or 8 bits, and when using 2 or 4 bits, it's possible
+    # to quantize the operations' activations using LUT.
+    four_bits = eight_bits.clone_and_edit(weights_n_bits=4)
+    two_bits = eight_bits.clone_and_edit(weights_n_bits=2)
+    mixed_precision_cfg_list = [eight_bits, four_bits, two_bits]
+
+    return eight_bits, mixed_precision_cfg_list
+
+
+def generate_tp_model(default_config: OpQuantizationConfig,
+                      base_config: OpQuantizationConfig,
+                      mixed_precision_cfg_list: List[OpQuantizationConfig],
+                      name: str) -> TargetPlatformModel:
+    """
+    Generates TargetPlatformModel with default defined Operators Sets, based on the given base configuration and
+    mixed-precision configurations options list.
+
+    Args
+        default_config: A default OpQuantizationConfig to set as the TP model default configuration.
+        base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only.
+        mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision
+            quantization configuration options.
+        name: The name of the TargetPlatformModel.
+
+    Returns: A TargetPlatformModel object.
+
+    """
+    # Create a QuantizationConfigOptions, which defines a set
+    # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example).
+    # If the QuantizationConfigOptions contains only one configuration,
+    # this configuration will be used for the operation quantization:
+    default_configuration_options = tp.QuantizationConfigOptions([default_config])
+
+    # Create a TargetPlatformModel and set its default quantization config.
+    # This default configuration will be used for all operations
+    # unless specified otherwise (see OperatorsSet, for example):
+    generated_tpc = tp.TargetPlatformModel(default_configuration_options, name=name)
+
+    # To start defining the model's components (such as operator sets, and fusing patterns),
+    # use 'with' the TargetPlatformModel instance, and create them as below:
+    with generated_tpc:
+        # Create an OperatorsSet to represent a set of operations.
+        # Each OperatorsSet has a unique label.
+        # If a quantization configuration options is passed, these options will
+        # be used for operations that will be attached to this set's label.
+        # Otherwise, it will be a configure-less set (used in fusing):
+
+        # May suit for operations like: Dropout, Reshape, etc.
+        tp.OperatorsSet("NoQuantization",
+                        tp.get_default_quantization_config_options().clone_and_edit(
+                            enable_weights_quantization=False,
+                            enable_activation_quantization=False))
+
+        # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
+        mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list,
+                                                                             base_config=base_config)
+
+        # Define operator sets that use mixed_precision_configuration_options:
+        conv = tp.OperatorsSet("Conv", mixed_precision_configuration_options)
+        fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options)
+
+        # Define operations sets without quantization configuration
+        # options (useful for creating fusing patterns, for example):
+        any_relu = tp.OperatorsSet("AnyReLU")
+        add = tp.OperatorsSet("Add")
+        sub = tp.OperatorsSet("Sub")
+        mul = tp.OperatorsSet("Mul")
+        div = tp.OperatorsSet("Div")
+        prelu = tp.OperatorsSet("PReLU")
+        swish = tp.OperatorsSet("Swish")
+        sigmoid = tp.OperatorsSet("Sigmoid")
+        tanh = tp.OperatorsSet("Tanh")
+
+        # Combine multiple operators into a single operator to avoid quantization between
+        # them. To do this we define fusing patterns using the OperatorsSets that were created.
+        # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
+        activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
+        activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid)
+        any_binary = tp.OperatorSetConcat(add, sub, mul, div)
+
+        # ------------------- #
+        # Fusions
+        # ------------------- #
+        tp.Fusing([conv, activations_after_conv_to_fuse])
+        tp.Fusing([fc, activations_after_fc_to_fuse])
+        tp.Fusing([any_binary, any_relu])
+
+    return generated_tpc
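
As the get_tp_model docstring suggests, the operator sets above can be reused with a different quantization mix by calling generate_tp_model directly. A brief sketch (the alternative bit widths are arbitrary examples):

    base_cfg, _ = get_op_quantization_configs()
    custom_mp_list = [base_cfg,
                      base_cfg.clone_and_edit(weights_n_bits=4),
                      base_cfg.clone_and_edit(weights_n_bits=3)]
    custom_tpm = generate_tp_model(default_config=base_cfg,
                                   base_config=base_cfg,
                                   mixed_precision_cfg_list=custom_mp_list,
                                   name='imx500_tp_model_custom')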