mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
  2. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
  3. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +13 -14
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
  6. model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
  7. model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
  8. model_compression_toolkit/core/common/constants.py +9 -4
  9. model_compression_toolkit/core/common/framework_implementation.py +32 -30
  10. model_compression_toolkit/core/common/graph/base_graph.py +8 -6
  11. model_compression_toolkit/core/common/logger.py +10 -2
  12. model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  16. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  17. model_compression_toolkit/core/common/model_validation.py +2 -1
  18. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  19. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
  20. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
  22. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
  23. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
  24. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
  25. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
  26. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
  27. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
  28. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  29. model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
  30. model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
  31. model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  32. model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
  33. model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
  34. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
  35. model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
  36. model_compression_toolkit/core/keras/constants.py +1 -0
  37. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
  38. model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
  39. model_compression_toolkit/core/keras/quantization_facade.py +3 -3
  40. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
  41. model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
  42. model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
  43. model_compression_toolkit/core/keras/reader/common.py +3 -2
  44. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
  45. model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
  46. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
  47. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
  48. model_compression_toolkit/core/pytorch/constants.py +5 -0
  49. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
  50. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  51. model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
  52. model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
  53. model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
  54. model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
  55. model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
  56. model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
  57. model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
  58. model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
  59. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
  60. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
  61. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
  62. model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
  63. model_compression_toolkit/exporter/__init__.py +5 -0
  64. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
  65. model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
  66. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
  67. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
  68. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
  69. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
  70. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
  71. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
  72. model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
  73. model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
  74. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
  75. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
  76. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
  77. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
  78. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
  79. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
  80. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
  81. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
  82. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
  83. model_compression_toolkit/gptq/__init__.py +6 -0
  84. model_compression_toolkit/gptq/common/gptq_config.py +57 -127
  85. model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
  86. model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
  87. model_compression_toolkit/gptq/common/gptq_training.py +32 -26
  88. model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
  89. model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
  90. model_compression_toolkit/gptq/keras/graph_info.py +24 -43
  91. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
  92. model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
  93. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
  94. model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
  95. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
  96. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  97. model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
  98. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
  100. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
  101. model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
  102. model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
  103. model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
  104. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
  105. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
  106. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
  107. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
  108. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  109. model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
  110. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  111. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
  112. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  113. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
  114. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
  115. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
  116. model_compression_toolkit/qat/common/qat_config.py +68 -0
  117. model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
  118. model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
  119. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
  120. model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
  121. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
  122. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
  123. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
  124. model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
  125. model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
  126. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
  127. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
  128. model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
  129. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
  130. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
  131. model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
  132. model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
  133. model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
  134. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
  135. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
  136. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
  137. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
  138. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
  139. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
  140. model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
  141. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
  142. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
  143. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
  144. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
  145. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
  146. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
  147. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
  148. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
  149. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
  150. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
  151. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  152. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
  153. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
  154. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
  155. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
  156. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
  157. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
  158. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
  159. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
  160. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
  161. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
  162. model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
  163. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
  164. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
  165. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
  166. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
  167. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
  168. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
  169. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
  170. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
  171. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
  172. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
  173. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
  174. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
  175. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
  176. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
  177. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
  178. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
  179. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
  180. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
  181. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
  182. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
  183. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
  184. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
  185. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
  186. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
  187. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
  188. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
  189. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
  190. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
  191. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
  192. model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
  193. model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
  194. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
  195. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
  196. model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
  197. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
  198. model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
  199. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
  200. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
  201. model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
  202. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
  203. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
  204. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
  205. model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
  206. model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
  207. model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
  208. model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
  209. model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
  210. model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
  211. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
  212. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
  213. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
  214. model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
  215. model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
  216. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
  217. model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
  218. model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
  219. model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
  220. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
  221. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
  222. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
  223. model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
  224. model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
  225. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
  226. model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
  227. model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
  228. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
  229. model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
  230. model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
  231. model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
  232. model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
  233. model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
  234. model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
  235. model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
  236. model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
  237. model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
  238. model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
  239. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
  240. {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
  241. /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
@@ -16,11 +16,13 @@
16
16
  from abc import abstractmethod
17
17
 
18
18
  import tensorflow as tf
19
- from keras.models import Model
19
+ from keras.engine.input_layer import InputLayer
20
+ from keras.models import Model, clone_model
20
21
  from packaging import version
21
22
 
22
23
  from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
23
24
  from model_compression_toolkit.core.common.user_info import UserInformation
25
+ from model_compression_toolkit.core.common.constants import INPUT_BASE_NAME
24
26
 
25
27
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
26
28
  if version.parse(tf.__version__) < version.parse("2.6"):
@@ -42,7 +44,7 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
42
44
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
43
45
  from model_compression_toolkit.core.common import BaseNode
44
46
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
45
- from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler, identity_wrapper
47
+ from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
46
48
  from model_compression_toolkit.core.keras.reader.connectivity_handler import OutTensor
47
49
 
48
50
  # In tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda.
@@ -93,7 +95,7 @@ class KerasModelBuilder(BaseModelBuilder):
93
95
  append2output=None,
94
96
  fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
95
97
  return_float_outputs: bool = False,
96
- wrapper: Callable = identity_wrapper):
98
+ wrapper: Callable = None):
97
99
  """
98
100
 
99
101
  Args:
@@ -101,6 +103,7 @@ class KerasModelBuilder(BaseModelBuilder):
101
103
  append2output: Nodes to append to model's output.
102
104
  fw_info: Information about the specific framework of the model that is built.
103
105
  return_float_outputs: Whether the model returns float tensors or not.
106
+ wrapper: A function wrapper keras Layers.
104
107
  """
105
108
 
106
109
  super().__init__(graph,
@@ -109,9 +112,9 @@ class KerasModelBuilder(BaseModelBuilder):
109
112
  return_float_outputs)
110
113
 
111
114
  # Build an OperationHandler to handle conversions from graph nodes to Keras operators.
112
- self.oh = OperationHandler(self.graph, wrapper)
115
+ self.oh = OperationHandler(self.graph)
116
+ self.wrapper = wrapper
113
117
 
114
- @abstractmethod
115
118
  def _quantize_node_activations(self,
116
119
  node: BaseNode,
117
120
  input_tensors: List[TFReference]) -> List[TFReference]:
@@ -126,7 +129,8 @@ class KerasModelBuilder(BaseModelBuilder):
126
129
  Output of the node.
127
130
 
128
131
  """
129
- raise NotImplemented(f'{self.__class__.__name__} have to implement a method for quantization activation nodes.')
132
+ raise NotImplemented(f'{self.__class__.__name__} did not implement a method for quantizating '
133
+ f'activation nodes.') # pragma: no cover
130
134
 
131
135
  def build_model(self) -> Tuple[Model, UserInformation]:
132
136
  """
@@ -149,10 +153,17 @@ class KerasModelBuilder(BaseModelBuilder):
149
153
  # Hold a dictionary from an input node to its corresponding input tensor. It is needed for when
150
154
  # building the model. Initially input nodes with input tensors are added to the dictionary,
151
155
  # as they're not added later.
152
- input_nodes_to_input_tensors = {inode: Input(inode.framework_attr[BATCH_INPUT_SHAPE][1:], name=inode.name)
156
+ input_nodes_to_input_tensors = {inode: Input(inode.framework_attr[BATCH_INPUT_SHAPE][1:],
157
+ name=f'{inode.name}_{INPUT_BASE_NAME}')
153
158
  for
154
159
  inode in self.graph.get_inputs()}
155
160
 
161
+ # Support adding Layer after input layers require us to store it in layer_to_node_dict
162
+ # dict offline (unlike other layers which stored during running).
163
+ for node, layer in self.oh.node_to_fw_op_dict.items():
164
+ if node.type == InputLayer:
165
+ self.oh.layer_to_node_dict[layer] = node
166
+
156
167
  # Build a list of the model's input tensors. Switching from a dictionary to a list
157
168
  # to keep the tensors input order, since inputs in Graph are ordered by their indices.
158
169
  inputs_list = []
@@ -198,6 +209,20 @@ class KerasModelBuilder(BaseModelBuilder):
198
209
 
199
210
  # Build the model.
200
211
  model = tf.keras.Model(inputs=inputs_list, outputs=model_output_tensors)
212
+
213
+ if self.wrapper is not None:
214
+ def _wrap(layer):
215
+ _node = self.oh.layer_to_node_dict.get(layer)
216
+ if _node is not None:
217
+ return self.wrapper(_node, layer)
218
+ elif is_layer_fake_quant(layer):
219
+ return layer
220
+ raise Exception( # pragma: no cover
221
+ f'Mismatch between keras model and graph cant find node named: '
222
+ f'{get_node_name_from_layer(layer)}')
223
+
224
+ model = clone_model(model, clone_function=_wrap)
225
+
201
226
  return model, self.graph.user_info
202
227
 
203
228
  def _convert_node2name(self, in_node_to_output_tensors_dict):
@@ -246,19 +271,20 @@ class KerasModelBuilder(BaseModelBuilder):
246
271
  input_tensors: List of references to Keras tensors that are the layer's inputs.
247
272
  op_func: Layer to apply to the input tensors.
248
273
  input_nodes_to_input_tensors: A dictionary from a node to its input tensors.
249
- mode: model quantization mode from ModelBuilderMode
250
274
 
251
275
  Returns:
252
276
  A list of references to Keras tensors. The layer's output tensors after applying the
253
277
  layer to the input tensors.
254
278
  """
255
-
256
279
  if len(input_tensors) == 0: # Placeholder handling
257
280
  out_tensors_of_n_float = input_nodes_to_input_tensors[n]
258
- out_tensors_of_n = out_tensors_of_n_float
259
- if n.is_activation_quantization_enabled():
281
+ if self.wrapper is not None:
282
+ # if a wrapper is defined, add an identity layer for cloning. The Identity will be warpped
283
+ out_tensors_of_n = op_func(out_tensors_of_n_float)
284
+ elif n.is_activation_quantization_enabled():
260
285
  out_tensors_of_n = self._quantize_node_activations(n, out_tensors_of_n_float)
261
-
286
+ else:
287
+ out_tensors_of_n = out_tensors_of_n_float
262
288
  else:
263
289
  input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
264
290
  # Build a functional node using its args
@@ -275,8 +301,8 @@ class KerasModelBuilder(BaseModelBuilder):
275
301
  out_tensors_of_n_float = op_func(input_tensors)
276
302
  out_tensors_of_n = out_tensors_of_n_float
277
303
 
278
- # Add a fake quant node if the node has an activation threshold.
279
- if n.is_activation_quantization_enabled():
304
+ # Add a fake quant node if the node has an activation threshold and a wrapper isn't defined
305
+ if n.is_activation_quantization_enabled() and self.wrapper is None:
280
306
  out_tensors_of_n = self._quantize_node_activations(n, out_tensors_of_n_float)
281
307
 
282
308
  # Save a mapping from the layer that created the tensor to the node (as this layer is not the
@@ -20,13 +20,13 @@ from packaging import version
20
20
  from tqdm import tqdm
21
21
 
22
22
  if version.parse(tf.__version__) < version.parse("2.6"):
23
- from tensorflow.python.keras.layers import Layer
23
+ from tensorflow.python.keras.layers import Layer # pragma: no cover
24
24
  else:
25
25
  from keras.engine.base_layer import Layer
26
26
 
27
27
  from typing import Any, Dict, List, Tuple
28
28
  from tensorflow.python.util.object_identity import Reference as TFReference
29
- from model_compression_toolkit.core.common.constants import EPS
29
+ from model_compression_toolkit.core.common.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
30
30
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
31
31
  from model_compression_toolkit.core import common
32
32
  from model_compression_toolkit.core.common import BaseNode, Graph
@@ -128,7 +128,7 @@ def keras_iterative_approx_jacobian_trace(graph_float: common.Graph,
128
128
  """
129
129
 
130
130
  if not all([images.shape[0] == 1 for node, images in model_input_tensors.items()]):
131
- Logger.critical("Iterative jacobian trace computation is only supported on a single image sample")
131
+ Logger.critical("Iterative jacobian trace computation is only supported on a single image sample") # pragma: no cover
132
132
 
133
133
  with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
134
134
  outputs, interest_points_tensors = _model_outputs_computation(graph_float,
@@ -136,32 +136,56 @@ def keras_iterative_approx_jacobian_trace(graph_float: common.Graph,
136
136
  interest_points,
137
137
  output_list,
138
138
  gradient_tape=g)
139
- outputs_jacobians_approx = []
140
- for output in outputs: # Per model's output tensor
141
- output = tf.reshape(output, shape=[output.shape[0], -1])
142
-
143
- ipts_jac_trace_approx = []
144
- for ipt in tqdm(interest_points_tensors): # Per Interest point activation tensor
145
- trace_jv = []
146
- for j in range(n_iter): # Approximation iterations
147
- # Getting a random vector with normal distribution
148
- v = tf.random.normal(shape=output.shape)
149
- f_v = tf.reduce_sum(v * output)
150
-
151
- with g.stop_recording():
152
- # Computing the jacobian approximation by getting the gradient of (output * v)
153
- jac_v = g.gradient(f_v, ipt, unconnected_gradients=tf.UnconnectedGradients.ZERO)
154
- jac_v = tf.reshape(jac_v, [jac_v.shape[0], -1])
155
- jac_trace_approx = tf.reduce_mean(tf.reduce_sum(tf.pow(jac_v, 2.0)))
156
- trace_jv.append(jac_trace_approx)
157
- ipts_jac_trace_approx.append(2 * tf.reduce_mean(trace_jv) / output.shape[-1]) # Get averaged squared jacobian trace approximation
158
- outputs_jacobians_approx.append(ipts_jac_trace_approx)
159
-
160
- mean_per_point = tf.reduce_mean(outputs_jacobians_approx, axis=0) # Get mean of jacobian approx of all model outputs
139
+
140
+ # Concat outputs
141
+ # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
142
+ unfold_outputs = []
143
+ for output in outputs:
144
+ if isinstance(output, List):
145
+ unfold_outputs += output
146
+ else:
147
+ unfold_outputs.append(output)
148
+
149
+ r_outputs = [tf.reshape(output, shape=[output.shape[0], -1]) for output in unfold_outputs]
150
+
151
+ concat_axis_dim = [o.shape[0] for o in r_outputs]
152
+ if not all(d == concat_axis_dim[0] for d in concat_axis_dim):
153
+ Logger.critical("Can't concat model's outputs for gradients calculation since the shape of the first axis " # pragma: no cover
154
+ "is not equal in all outputs.")
155
+
156
+ output = tf.concat(r_outputs, axis=1)
157
+
158
+ ipts_jac_trace_approx = []
159
+ for ipt in tqdm(interest_points_tensors): # Per Interest point activation tensor
160
+ trace_jv = []
161
+ for j in range(n_iter): # Approximation iterations
162
+ # Getting a random vector with normal distribution
163
+ v = tf.random.normal(shape=output.shape)
164
+ f_v = tf.reduce_sum(v * output)
165
+
166
+ with g.stop_recording():
167
+ # Computing the jacobian approximation by getting the gradient of (output * v)
168
+ jac_v = g.gradient(f_v, ipt, unconnected_gradients=tf.UnconnectedGradients.ZERO)
169
+ jac_v = tf.reshape(jac_v, [jac_v.shape[0], -1])
170
+ jac_trace_approx = tf.reduce_mean(tf.reduce_sum(tf.pow(jac_v, 2.0)))
171
+
172
+ # If the change to the mean Jacobian approximation is insignificant we stop the calculation
173
+ if j > MIN_JACOBIANS_ITER:
174
+ new_mean = np.mean([jac_trace_approx, *trace_jv])
175
+ delta = new_mean - np.mean(trace_jv)
176
+ if np.abs(delta) / (np.abs(new_mean) + 1e-6) < JACOBIANS_COMP_TOLERANCE:
177
+ trace_jv.append(jac_trace_approx)
178
+ break
179
+
180
+ trace_jv.append(jac_trace_approx)
181
+ ipts_jac_trace_approx.append(2 * tf.reduce_mean(trace_jv) / output.shape[-1]) # Get averaged squared jacobian trace approximation
182
+
183
+ ipts_jac_trace_approx = tf.reduce_mean([ipts_jac_trace_approx], axis=0) # Just to get one tensor instead of list of tensors with single element
184
+
161
185
  if norm_weights:
162
- return _normalize_weights(mean_per_point, all_outputs_indices, alpha)
186
+ return _normalize_weights(ipts_jac_trace_approx, all_outputs_indices, alpha)
163
187
  else:
164
- return mean_per_point
188
+ return ipts_jac_trace_approx
165
189
 
166
190
 
167
191
  def _model_outputs_computation(graph_float: common.Graph,
@@ -101,6 +101,7 @@ RELU_POT_BOUND = 8.0
101
101
 
102
102
  # Supported TP models names for Tensorflow:
103
103
  DEFAULT_TP_MODEL = 'default'
104
+ IMX500_TP_MODEL = 'imx500'
104
105
  TFLITE_TP_MODEL = 'tflite'
105
106
  QNNPACK_TP_MODEL = 'qnnpack'
106
107
 
@@ -23,6 +23,7 @@ else:
23
23
  from keras.layers.core import TFOpLambda
24
24
  from keras.layers import MultiHeadAttention, Conv2D, Softmax, Concatenate, Reshape, Permute
25
25
 
26
+ from model_compression_toolkit.core.common.logger import Logger
26
27
  from model_compression_toolkit.core import common
27
28
  from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
28
29
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -448,7 +449,7 @@ class MultiHeadAttentionDecomposition(common.BaseSubstitution):
448
449
  """
449
450
 
450
451
  if mha_node.reuse:
451
- raise Exception("MCT doesn't support reuse of MultiHeadAttention layer")
452
+ Logger.error("MCT doesn't support reuse of MultiHeadAttention layer") # pragma: no cover
452
453
  params = MHAParams(mha_node)
453
454
 
454
455
  mha_in_edges = graph.in_edges(mha_node)
@@ -156,10 +156,10 @@ else:
156
156
  def keras_kpi_data(*args, **kwargs):
157
157
  Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
158
158
  'when using keras_kpi_data. '
159
- 'Could not find Tensorflow package.')
159
+ 'Could not find Tensorflow package.') # pragma: no cover
160
160
 
161
161
 
162
162
  def keras_kpi_data_experimental(*args, **kwargs):
163
163
  Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
164
164
  'when using keras_kpi_data. '
165
- 'Could not find Tensorflow package.')
165
+ 'Could not find Tensorflow package.') # pragma: no cover
@@ -19,7 +19,7 @@ from model_compression_toolkit.core import common
19
19
  from model_compression_toolkit.core.common import Logger
20
20
  from model_compression_toolkit.core.common.constants import TENSORFLOW
21
21
  from model_compression_toolkit.core.common.user_info import UserInformation
22
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GradientPTQConfigV2
22
+ from model_compression_toolkit.gptq import GradientPTQConfig, GradientPTQConfigV2
23
23
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
24
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
25
  from model_compression_toolkit.core.common.network_editors.actions import EditRule
@@ -281,10 +281,10 @@ else:
281
281
  def keras_post_training_quantization(*args, **kwargs):
282
282
  Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
283
283
  'when using keras_post_training_quantization. '
284
- 'Could not find Tensorflow package.')
284
+ 'Could not find Tensorflow package.') # pragma: no cover
285
285
 
286
286
 
287
287
  def keras_post_training_quantization_mixed_precision(*args, **kwargs):
288
288
  Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
289
289
  'when using keras_post_training_quantization_mixed_precision. '
290
- 'Could not find Tensorflow package.')
290
+ 'Could not find Tensorflow package.') # pragma: no cover
@@ -20,7 +20,7 @@ import tensorflow as tf
20
20
  import numpy as np
21
21
  from tensorflow.python.util.object_identity import Reference as TFReference
22
22
 
23
- from model_compression_toolkit.core.common import Logger
23
+ from model_compression_toolkit.core.common.logger import Logger
24
24
  from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
25
25
  from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import threshold_is_power_of_two
26
26
 
@@ -68,10 +68,12 @@ def power_of_two_quantization(activation_n_bits: int,
68
68
  activation_threshold = quantization_params.get(THRESHOLD)
69
69
  activation_is_signed = quantization_params.get(SIGNED)
70
70
 
71
- if activation_threshold is None or activation_is_signed is None:
72
- return None
71
+ if activation_threshold is None:
72
+ Logger.error("Activation threshold is None") # pragma: no cover
73
+ if activation_is_signed is None:
74
+ Logger.error("activation_is_signed is None") # pragma: no cover
73
75
  if not threshold_is_power_of_two(activation_threshold, per_channel=False):
74
- return None
76
+ Logger.error("Activation threshold is not power of two") # pragma: no cover
75
77
 
76
78
  min_value, max_value = quantizer_min_max_calculator(activation_threshold,
77
79
  activation_n_bits,
@@ -96,8 +98,10 @@ def symmetric_quantization(activation_n_bits: int,
96
98
  activation_threshold = quantization_params.get(THRESHOLD)
97
99
  activation_is_signed = quantization_params.get(SIGNED)
98
100
 
99
- if activation_threshold is None or activation_is_signed is None:
100
- return None
101
+ if activation_threshold is None:
102
+ Logger.error("Activation threshold is None") # pragma: no cover
103
+ if activation_is_signed is None:
104
+ Logger.error("activation_is_signed is None") # pragma: no cover
101
105
 
102
106
  min_value, max_value = quantizer_min_max_calculator(activation_threshold,
103
107
  activation_n_bits,
@@ -121,8 +125,10 @@ def uniform_quantization(activation_n_bits: int,
121
125
  """
122
126
  min_value, max_value = quantization_params.get(RANGE_MIN), quantization_params.get(RANGE_MAX)
123
127
 
124
- if min_value is None or max_value is None:
125
- return None
128
+ if min_value is None:
129
+ Logger.error("Min value is None") # pragma: no cover
130
+ if max_value is None:
131
+ Logger.error("Max value is None") # pragma: no cover
126
132
 
127
133
  return lambda x: q(x, min_value, max_value, activation_n_bits)
128
134
 
@@ -141,7 +147,7 @@ def q(x: TFReference, min_value, max_value, activation_n_bits) -> TFReference:
141
147
  The fake-quantized input tensor.
142
148
  """
143
149
  if x.dtype != tf.float32:
144
- x = tf.cast(x, dtype=tf.float32)
150
+ x = tf.cast(x, dtype=tf.float32) # pragma: no cover
145
151
 
146
152
  # fake_quant_with_min_max_vars expects to get x of float32
147
153
  return tf.quantization.fake_quant_with_min_max_vars(x,
@@ -23,6 +23,7 @@ from tensorflow_model_optimization.python.core.quantization.keras.quantize_confi
23
23
 
24
24
 
25
25
  from model_compression_toolkit.core.common import BaseNode
26
+ from model_compression_toolkit.core.common.constants import INPUT_BASE_NAME
26
27
 
27
28
 
28
29
  class InputLayerWrapperTransform(InputLayerQuantize):
@@ -55,7 +56,7 @@ class InputLayerWrapperTransform(InputLayerQuantize):
55
56
  self.wrapper_class = wrapper_class
56
57
 
57
58
  def pattern(self):
58
- return transforms.LayerPattern('InputLayer', config={'name': self.name})
59
+ return transforms.LayerPattern('InputLayer', config={'name': f'{self.name}_{INPUT_BASE_NAME}'})
59
60
 
60
61
  def replacement(self, match_layer):
61
62
  layer_wrapper = self.wrapper_class(InputLayer(input_shape=self.input_layer.input_shape),
@@ -60,7 +60,7 @@ class LUTFakeQuant(Layer):
60
60
 
61
61
  """
62
62
  if self.activation_is_signed is None or self.cluster_centers is None or self.threshold is None:
63
- return None
63
+ return None # pragma: no cover
64
64
 
65
65
  _quant_output = self.lut_kmeans_quantizer(input_data)
66
66
  return _quant_output
@@ -29,6 +29,7 @@ else:
29
29
  from keras.engine.functional import Functional
30
30
  from keras.engine.sequential import Sequential
31
31
 
32
+ from model_compression_toolkit.core.common.logger import Logger
32
33
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
33
34
 
34
35
 
@@ -46,7 +47,7 @@ def is_node_an_input_layer(node: BaseNode) -> bool:
46
47
  elif isinstance(node, KerasNode):
47
48
  return isinstance(node.layer, InputLayer)
48
49
  else:
49
- raise Exception('Node to check has to be either a graph node or a keras node')
50
+ Logger.error('Node to check has to be either a graph node or a keras node') # pragma: no cover
50
51
 
51
52
 
52
53
  def is_node_a_model(node: BaseNode) -> bool:
@@ -63,5 +64,5 @@ def is_node_a_model(node: BaseNode) -> bool:
63
64
  elif isinstance(node, KerasNode):
64
65
  return isinstance(node.layer, Functional) or isinstance(node.layer, Sequential)
65
66
  else:
66
- raise Exception('Node to check has to be either a graph node or a keras node')
67
+ Logger.error('Node to check has to be either a graph node or a keras node') # pragma: no cover
67
68
 
@@ -35,4 +35,17 @@ def node_builder(n: BaseNode) -> Module:
35
35
  node_instance = n.type(**framework_attr)
36
36
  node_instance.load_state_dict({k: torch.Tensor(v) for k, v in n.weights.items()}, strict=False)
37
37
  set_model(node_instance)
38
- return node_instance
38
+ return node_instance
39
+
40
+
41
+ def identity_wrapper(node: BaseNode, module: Module):
42
+ """
43
+ A function which takes a computational graph node and a pytorch module and return an identity wrapping which return the layer itself
44
+ Args:
45
+ node: A node of mct graph.
46
+ layer: A pytorch module
47
+ Returns: pytorch module
48
+ """
49
+ return module
50
+
51
+
@@ -18,15 +18,18 @@ import torch
18
18
  import torch.autograd as autograd
19
19
  from networkx import topological_sort
20
20
  from tqdm import tqdm
21
+ import numpy as np
21
22
 
22
23
  from model_compression_toolkit.core import common
23
24
  from model_compression_toolkit.core.common import BaseNode, Graph
24
- from model_compression_toolkit.core.common.constants import EPS
25
+ from model_compression_toolkit.core.common.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
25
26
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
26
27
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
27
28
  from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
28
- from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
29
- from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
29
+ from model_compression_toolkit.core.pytorch.constants import BUFFER
30
+ from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, BufferHolder
31
+ from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy, get_working_device
32
+ from model_compression_toolkit.core.common.logger import Logger
30
33
 
31
34
 
32
35
  def build_input_tensors_list(node: BaseNode,
@@ -95,8 +98,11 @@ def generate_outputs(
95
98
  output = []
96
99
  for n in out_nodes:
97
100
  out_tensors_of_n = node_to_output_tensors_dict.get(n)
98
- if len(out_tensors_of_n) > 1:
99
- output.append(out_tensors_of_n)
101
+ if len(out_tensors_of_n) > 1 or isinstance(out_tensors_of_n[0], tuple):
102
+ if isinstance(out_tensors_of_n[0], tuple):
103
+ out_tensors_of_n = out_tensors_of_n[0]
104
+ out_tensors_of_n = [torch.cat(out_tensors_of_n)]
105
+ output.append(torch.concat(out_tensors_of_n))
100
106
  else:
101
107
  output += out_tensors_of_n
102
108
  return output
@@ -128,7 +134,13 @@ class PytorchModelGradients(torch.nn.Module):
128
134
 
129
135
  for n in self.node_sort:
130
136
  if not isinstance(n, FunctionalNode):
131
- self.add_module(n.name, node_builder(n))
137
+ if n.type == BufferHolder:
138
+ self.add_module(n.name, node_builder(n))
139
+ self.get_submodule(n.name). \
140
+ register_buffer(n.name,
141
+ torch.Tensor(n.get_weights_by_keys(BUFFER)).to(get_working_device()))
142
+ else:
143
+ self.add_module(n.name, node_builder(n))
132
144
 
133
145
  def forward(self,
134
146
  *args: Any) -> Any:
@@ -153,7 +165,7 @@ class PytorchModelGradients(torch.nn.Module):
153
165
  input_tensors,
154
166
  op_func=op_func)
155
167
 
156
- if isinstance(out_tensors_of_n, list):
168
+ if isinstance(out_tensors_of_n, list) or isinstance(out_tensors_of_n, tuple):
157
169
  output_t = []
158
170
  for t in out_tensors_of_n:
159
171
  if n in self.interest_points:
@@ -162,13 +174,19 @@ class PytorchModelGradients(torch.nn.Module):
162
174
  self.interest_points_tensors.append(t)
163
175
  else:
164
176
  # We get here in case we have an output node, which is an interest point,
165
- # but it is not differentiable. We need to add this dummy tensor to in order to include this
166
- # node in the coming weights computation.
167
- self.interest_points_tensors.append(torch.tensor([0.0],
177
+ # but it is not differentiable. We need to add this dummy tensor in order to include this
178
+ # node in the future weights computation.
179
+ # Note that this call is excluded from tests coverage,
180
+ # since we do not suppose to get here - there is no valid operation that is both
181
+ # non-differentiable and return output as a list or a tuple
182
+ self.interest_points_tensors.append(torch.tensor([0.0], # pragma: no cover
168
183
  requires_grad=True,
169
184
  device=t.device))
170
- break
185
+ break # pragma: no cover
171
186
  output_t.append(t)
187
+ if isinstance(out_tensors_of_n, tuple):
188
+ # If the node's output is a Tuple, then we want to keep it as a Tuple
189
+ output_t = [tuple(output_t)]
172
190
  node_to_output_tensors_dict.update({n: output_t})
173
191
  else:
174
192
  assert isinstance(out_tensors_of_n, torch.Tensor)
@@ -178,8 +196,8 @@ class PytorchModelGradients(torch.nn.Module):
178
196
  self.interest_points_tensors.append(out_tensors_of_n)
179
197
  else:
180
198
  # We get here in case we have an output node, which is an interest point,
181
- # but it is not differentiable. We need to add this dummy tensor to in order to include this
182
- # node in the coming weights computation.
199
+ # but it is not differentiable. We need to add this dummy tensor in order to include this
200
+ # node in the future weights computation.
183
201
  self.interest_points_tensors.append(torch.tensor([0.0],
184
202
  requires_grad=True,
185
203
  device=out_tensors_of_n.device))
@@ -233,45 +251,69 @@ def pytorch_iterative_approx_jacobian_trace(graph_float: common.Graph,
233
251
  output_tensors = model_grads_net(model_input_tensors)
234
252
  device = output_tensors[0].device
235
253
 
236
- outputs_jacobians_approx = []
237
- for output in output_tensors: # Per model's output tensor
238
- output = torch.reshape(output, shape=[output.shape[0], -1])
239
-
240
- ipts_jac_trace_approx = []
241
- for ipt in tqdm(model_grads_net.interest_points_tensors): # Per Interest point activation tensor
242
- trace_jv = []
243
- for j in range(n_iter): # Approximation iterations
244
- # Getting a random vector with normal distribution
245
- v = torch.randn(output.shape, device=device)
246
- f_v = torch.sum(v * output)
247
-
248
- # Computing the jacobian approximation by getting the gradient of (output * v)
249
- jac_v = autograd.grad(outputs=f_v,
250
- inputs=ipt,
251
- retain_graph=True,
252
- allow_unused=True)[0]
253
- if jac_v is None:
254
- # In case we have an output node, which is an interest point, but it is not differentiable,
255
- # we still want to set some weight for it. For this, we need to add this dummy tensor to the ipt
256
- # jacobian traces list.
257
- trace_jv.append(torch.tensor([0.0],
258
- requires_grad=True,
259
- device=device))
254
+
255
+ # Concat outputs
256
+ # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
257
+ unfold_outputs = []
258
+ for output in output_tensors:
259
+ if isinstance(output, List):
260
+ unfold_outputs += output
261
+ else:
262
+ unfold_outputs.append(output)
263
+
264
+ r_outputs = [torch.reshape(output, shape=[output.shape[0], -1]) for output in unfold_outputs]
265
+
266
+ concat_axis_dim = [o.shape[0] for o in r_outputs]
267
+ if not all(d == concat_axis_dim[0] for d in concat_axis_dim):
268
+ Logger.critical("Can't concat model's outputs for gradients calculation since the shape of the first axis " # pragma: no cover
269
+ "is not equal in all outputs.")
270
+
271
+ output = torch.concat(r_outputs, dim=1)
272
+
273
+ ipts_jac_trace_approx = []
274
+ for ipt in tqdm(model_grads_net.interest_points_tensors): # Per Interest point activation tensor
275
+ trace_jv = []
276
+ for j in range(n_iter): # Approximation iterations
277
+ # Getting a random vector with normal distribution
278
+ v = torch.randn(output.shape, device=device)
279
+ f_v = torch.sum(v * output)
280
+
281
+ # Computing the jacobian approximation by getting the gradient of (output * v)
282
+ jac_v = autograd.grad(outputs=f_v,
283
+ inputs=ipt,
284
+ retain_graph=True,
285
+ allow_unused=True)[0]
286
+ if jac_v is None:
287
+ # In case we have an output node, which is an interest point, but it is not differentiable,
288
+ # we still want to set some weight for it. For this, we need to add this dummy tensor to the ipt
289
+ # jacobian traces list.
290
+ trace_jv.append(torch.tensor([0.0],
291
+ requires_grad=True,
292
+ device=device))
293
+ break
294
+ jac_v = torch.reshape(jac_v, [jac_v.shape[0], -1])
295
+ jac_trace_approx = torch.mean(torch.sum(torch.pow(jac_v, 2.0)))
296
+
297
+ # If the change to the mean Jacobian approximation is insignificant we stop the calculation
298
+ if j > MIN_JACOBIANS_ITER:
299
+ new_mean = torch.mean(torch.stack([jac_trace_approx, *trace_jv]))
300
+ delta = new_mean - torch.mean(torch.stack(trace_jv))
301
+ if torch.abs(delta) / (torch.abs(new_mean) + 1e-6) < JACOBIANS_COMP_TOLERANCE:
302
+ trace_jv.append(jac_trace_approx)
260
303
  break
261
- jac_v = torch.reshape(jac_v, [jac_v.shape[0], -1])
262
- jac_trace_approx = torch.mean(torch.sum(torch.pow(jac_v, 2.0)))
263
- trace_jv.append(jac_trace_approx)
264
- ipts_jac_trace_approx.append(2*torch.mean(torch.stack(trace_jv))/output.shape[-1]) # Get averaged jacobian trace approximation
265
- outputs_jacobians_approx.append(ipts_jac_trace_approx)
266
304
 
267
- mean_per_point = torch_tensor_to_numpy(torch.mean(torch.Tensor(outputs_jacobians_approx), dim=0)) # Get mean of jacobians of all model's outputs
305
+ trace_jv.append(jac_trace_approx)
306
+ ipts_jac_trace_approx.append(2*torch.mean(torch.stack(trace_jv))/output.shape[-1]) # Get averaged jacobian trace approximation
307
+
308
+ ipts_jac_trace_approx = torch_tensor_to_numpy(torch.Tensor(ipts_jac_trace_approx)) # Just to get one tensor instead of list of tensors with single element
309
+
268
310
  if norm_weights:
269
- return _normalize_weights(mean_per_point, all_outputs_indices, alpha)
311
+ return _normalize_weights(ipts_jac_trace_approx, all_outputs_indices, alpha)
270
312
  else:
271
- return mean_per_point
313
+ return ipts_jac_trace_approx
272
314
 
273
315
 
274
- def _normalize_weights(jacobians_traces: torch.Tensor,
316
+ def _normalize_weights(jacobians_traces: np.ndarray,
275
317
  all_outputs_indices: List[int],
276
318
  alpha: float) -> List[float]:
277
319
  """