onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3605 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # -------------------------------------------------------------------------
5
+ """
6
+ This converts GPT2 or T5 model to onnx with beam search operator.
7
+
8
+ Example 1: convert gpt2 model with beam search:
9
+ python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx
10
+
11
+ Example 2: convert gpt2 model with beam search containing specific cuda optimizations:
12
+ python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu \
13
+ --past_present_share_buffer --use_decoder_masked_attention
14
+
15
+ Example 3: convert gpt2 model with beam search with mixed precision and enable SkipLayerNorm strict mode:
16
+ python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu -p fp16 --use_sln_strict_mode
17
+
18
+ Example 4: convert T5 model with beam search in two steps:
19
+ python -m models.t5.convert_to_onnx -m t5-small
20
+ python convert_generation.py -m t5-small --model_type t5 \
21
+ --decoder_onnx ./onnx_models/t5-small_decoder.onnx \
22
+ --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder.onnx \
23
+ --output ./onnx_models/t5_small_beam_search.onnx
24
+
25
+ Example 5: convert T5 model with beam search. All in one step:
26
+ python convert_generation.py -m t5-small --model_type t5 --output t5_small_beam_search.onnx
27
+
28
+ Example 6: convert T5 model with beam search containing specific cuda optimizations. All in one step:
29
+ python convert_generation.py -m t5-small --model_type t5 --output t5_small_beam_search.onnx \
30
+ --use_gpu --past_present_share_buffer --use_decoder_masked_attention
31
+
32
+ Example 7: convert MT5 model with external data file like mt5-base-beamsearch.onnx.data in below example.
33
+ python convert_generation.py -m google/mt5-base --model_type mt5 --output mt5-base-beamsearch.onnx -e
34
+
35
+ Example 8: convert gpt2 model with greedy search:
36
+ python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1
37
+
38
+ Example 9: convert gpt2 model with sampling:
39
+ python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6
40
+ """
41
+
42
+ import argparse
43
+ import logging
44
+ import math
45
+ import os
46
+ import time
47
+ from enum import Enum
48
+ from pathlib import Path
49
+ from typing import Any
50
+
51
+ import numpy as np
52
+ import onnx
53
+ import torch
54
+ from benchmark_helper import Precision, setup_logger
55
+ from fusion_utils import NumpyHelper
56
+ from onnx import GraphProto, ModelProto, TensorProto
57
+ from onnx_model import OnnxModel
58
+ from transformers import (
59
+ GPT2Config,
60
+ GPT2LMHeadModel,
61
+ GPT2Tokenizer,
62
+ MT5Config,
63
+ MT5ForConditionalGeneration,
64
+ T5Config,
65
+ T5ForConditionalGeneration,
66
+ T5Tokenizer,
67
+ )
68
+
69
+ from onnxruntime import (
70
+ GraphOptimizationLevel,
71
+ InferenceSession,
72
+ SessionOptions,
73
+ get_available_providers,
74
+ )
75
+ from onnxruntime.transformers.models.gpt2.convert_to_onnx import (
76
+ main as convert_gpt2_to_onnx,
77
+ )
78
+ from onnxruntime.transformers.models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS
79
+ from onnxruntime.transformers.models.t5.convert_to_onnx import (
80
+ export_onnx_models as export_t5_onnx_models,
81
+ )
82
+ from onnxruntime.transformers.models.t5.t5_helper import (
83
+ PRETRAINED_MT5_MODELS,
84
+ PRETRAINED_T5_MODELS,
85
+ )
86
+
87
+ logger = logging.getLogger("")
88
+
89
+
90
+ class GenerationType(Enum):
91
+ BEAMSEARCH = "beam_search"
92
+ GREEDYSEARCH = "greedy_search"
93
+ SAMPLING = "sampling"
94
+
95
+ def __str__(self):
96
+ return self.value
97
+
98
+
99
+ def parse_arguments(argv: list[str] | None = None) -> argparse.Namespace:
100
+ """Parse arguments
101
+
102
+ Args:
103
+ argv (Optional[List[str]], optional): _description_. Defaults to None.
104
+
105
+ Returns:
106
+ argparse.Namespace: Parsed arguments.
107
+ """
108
+ parser = argparse.ArgumentParser()
109
+
110
+ input_group = parser.add_argument_group("Input options")
111
+
112
+ input_group.add_argument(
113
+ "-m",
114
+ "--model_name_or_path",
115
+ required=True,
116
+ type=str,
117
+ help="Pytorch model checkpoint path, or pretrained model name in the list: "
118
+ + ", ".join(PRETRAINED_GPT2_MODELS + PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS),
119
+ )
120
+
121
+ input_group.add_argument(
122
+ "--model_type",
123
+ required=False,
124
+ type=str,
125
+ default="gpt2",
126
+ choices=["gpt2", "t5", "mt5"],
127
+ help="Model type (default is gpt2) in the list: " + ", ".join(["gpt2", "t5", "mt5"]),
128
+ )
129
+
130
+ input_group.add_argument(
131
+ "--cache_dir",
132
+ required=False,
133
+ type=str,
134
+ default=os.path.join(".", "cache_models"),
135
+ help="Directory to cache pre-trained models",
136
+ )
137
+
138
+ input_group.add_argument(
139
+ "--decoder_onnx",
140
+ required=False,
141
+ type=str,
142
+ default="",
143
+ help="Path of onnx model for decoder. Specify it when you have exported the model.",
144
+ )
145
+
146
+ input_group.add_argument(
147
+ "--encoder_decoder_init_onnx",
148
+ required=False,
149
+ type=str,
150
+ default="",
151
+ help="Path of ONNX model for encoder and decoder initialization. Specify it when you have exported the model.",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--verbose",
156
+ required=False,
157
+ action="store_true",
158
+ help="Print more information",
159
+ )
160
+ parser.set_defaults(verbose=False)
161
+
162
+ output_group = parser.add_argument_group("Output options")
163
+
164
+ output_group.add_argument(
165
+ "--output",
166
+ required=True,
167
+ type=str,
168
+ help="Output path for onnx model with beam search.",
169
+ )
170
+
171
+ output_group.add_argument(
172
+ "-p",
173
+ "--precision",
174
+ required=False,
175
+ type=str,
176
+ default=Precision.FLOAT32.value,
177
+ choices=[Precision.FLOAT32.value, Precision.FLOAT16.value],
178
+ help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision",
179
+ )
180
+
181
+ output_group.add_argument(
182
+ "-b",
183
+ "--op_block_list",
184
+ required=False,
185
+ nargs="*",
186
+ default=["auto"],
187
+ help="Disable certain onnx operators when exporting model to onnx format. When using default"
188
+ 'value for gpt2 type of model fp16 precision, it will be set to ["Add", "LayerNormalization",'
189
+ ' "SkipLayerNormalization", "FastGelu"]. Other situation, it will be set to []',
190
+ )
191
+
192
+ output_group.add_argument(
193
+ "-e",
194
+ "--use_external_data_format",
195
+ required=False,
196
+ action="store_true",
197
+ help="save external data for model > 2G",
198
+ )
199
+ output_group.set_defaults(use_external_data_format=False)
200
+
201
+ output_group.add_argument(
202
+ "-s",
203
+ "--run_shape_inference",
204
+ required=False,
205
+ action="store_true",
206
+ help="run shape inference",
207
+ )
208
+ output_group.set_defaults(run_shape_inference=False)
209
+
210
+ output_group.add_argument(
211
+ "-dpvs",
212
+ "--disable_pad_vocab_size",
213
+ required=False,
214
+ action="store_true",
215
+ help="Do not pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is"
216
+ " the vocab size. The logits MatMul may hence be of poor performance for fp16 precision.",
217
+ )
218
+ output_group.set_defaults(disable_pad_vocab_size=False)
219
+
220
+ output_group.add_argument(
221
+ "-dsgd",
222
+ "--disable_separate_gpt2_decoder_for_init_run",
223
+ required=False,
224
+ action="store_true",
225
+ help="Do not create separate decoder subgraphs for initial and remaining runs. This does not allow "
226
+ "for optimizations based on sequence lengths in each subgraph",
227
+ )
228
+ output_group.set_defaults(disable_separate_gpt2_decoder_for_init_run=False)
229
+
230
+ output_group.add_argument(
231
+ "-i",
232
+ "--disable_shared_initializers",
233
+ required=False,
234
+ action="store_true",
235
+ help="do not share initializers in encoder and decoder for T5 or in the init decoder and decoder for "
236
+ "GPT2. It will increase memory usage of t5/mt5/gpt2 models.",
237
+ )
238
+ output_group.set_defaults(disable_shared_initializers=False)
239
+
240
+ output_group.add_argument(
241
+ "--encoder_decoder_init",
242
+ required=False,
243
+ action="store_true",
244
+ help="Add decoder initialization to encoder for T5 model. This is legacy format that will be deprecated.",
245
+ )
246
+ output_group.set_defaults(encoder_decoder_init=False)
247
+
248
+ model_group = parser.add_argument_group("Beam search parameters that stored in the output model")
249
+
250
+ model_group.add_argument(
251
+ "--output_sequences_scores",
252
+ required=False,
253
+ action="store_true",
254
+ help="output sequences scores",
255
+ )
256
+ model_group.set_defaults(output_sequences_scores=False)
257
+
258
+ model_group.add_argument(
259
+ "--output_token_scores",
260
+ required=False,
261
+ action="store_true",
262
+ help="output token scores",
263
+ )
264
+ model_group.set_defaults(output_token_scores=False)
265
+
266
+ model_group.add_argument("--early_stopping", required=False, action="store_true")
267
+ model_group.set_defaults(early_stopping=False)
268
+
269
+ model_group.add_argument(
270
+ "--no_repeat_ngram_size",
271
+ type=int,
272
+ required=False,
273
+ default=0,
274
+ help="No repeat ngram size",
275
+ )
276
+
277
+ model_group.add_argument(
278
+ "--vocab_mask",
279
+ required=False,
280
+ action="store_true",
281
+ help="Enable vocab_mask. This mask applies only to every generated token to filter some bad words.",
282
+ )
283
+ model_group.set_defaults(vocab_mask=False)
284
+
285
+ model_group.add_argument(
286
+ "--past_present_share_buffer",
287
+ required=False,
288
+ action="store_true",
289
+ help="Use shared buffer for past and present, currently work for gpt2 greedy/sampling search.",
290
+ )
291
+ model_group.set_defaults(past_present_share_buffer=False)
292
+
293
+ model_group.add_argument(
294
+ "--use_decoder_masked_attention",
295
+ required=False,
296
+ action="store_true",
297
+ help="Uses `DecoderMaskedSelfAttention` or `DecoderMaskedMultiHeadAttention` to optimize the decoding Attention computation. "
298
+ "Must be used with `past_present_share_buffer`. Currently, only Attention head sizes of 32, 64 and 128 are supported.",
299
+ )
300
+ model_group.set_defaults(use_decoder_masked_attention=False)
301
+
302
+ model_group.add_argument(
303
+ "--prefix_vocab_mask",
304
+ required=False,
305
+ action="store_true",
306
+ help="Enable prefix_vocab_mask. This mask can be used to filter bad words in the first generated token only",
307
+ )
308
+ model_group.set_defaults(prefix_vocab_mask=False)
309
+
310
+ model_group.add_argument(
311
+ "--custom_attention_mask",
312
+ required=False,
313
+ action="store_true",
314
+ help="Enable custom_attention_mask. This mask can be used to replace default encoder attention mask",
315
+ )
316
+ model_group.set_defaults(custom_attention_mask=False)
317
+
318
+ model_group.add_argument(
319
+ "--presence_mask",
320
+ required=False,
321
+ action="store_true",
322
+ help="Presence mask for custom sampling",
323
+ )
324
+ model_group.set_defaults(presence_mask=False)
325
+
326
+ model_group.add_argument(
327
+ "--seed",
328
+ required=False,
329
+ action="store_true",
330
+ help="Random seed for sampling op",
331
+ )
332
+ model_group.set_defaults(seed=False)
333
+
334
+ beam_parameters_group = parser.add_argument_group(
335
+ "Beam search parameters not stored in the output model, for testing parity and performance"
336
+ )
337
+
338
+ beam_parameters_group.add_argument("--min_length", type=int, required=False, default=1, help="Min sequence length")
339
+
340
+ beam_parameters_group.add_argument("--max_length", type=int, required=False, default=50, help="Max sequence length")
341
+
342
+ beam_parameters_group.add_argument("--num_beams", type=int, required=False, default=4, help="Beam size")
343
+
344
+ beam_parameters_group.add_argument(
345
+ "--num_return_sequences",
346
+ type=int,
347
+ required=False,
348
+ default=1,
349
+ help="Number of return sequence <= num_beams",
350
+ )
351
+
352
+ beam_parameters_group.add_argument(
353
+ "--length_penalty",
354
+ type=float,
355
+ required=False,
356
+ default=1,
357
+ help="Positive. >1 to penalize and <1 to encourage short sentence.",
358
+ )
359
+
360
+ beam_parameters_group.add_argument(
361
+ "--repetition_penalty",
362
+ type=float,
363
+ required=False,
364
+ default=1,
365
+ help="Positive. >1 to penalize and <1 to encourage.",
366
+ )
367
+
368
+ beam_parameters_group.add_argument(
369
+ "--temperature",
370
+ type=float,
371
+ required=False,
372
+ default=1.0,
373
+ help="The value used to module the next token probabilities.",
374
+ )
375
+
376
+ beam_parameters_group.add_argument(
377
+ "--top_p",
378
+ type=float,
379
+ required=False,
380
+ default=1.0,
381
+ help="Top P for sampling",
382
+ )
383
+
384
+ beam_parameters_group.add_argument(
385
+ "--filter_value",
386
+ type=float,
387
+ required=False,
388
+ default=-float("Inf"),
389
+ help="Filter value for Top P sampling",
390
+ )
391
+
392
+ beam_parameters_group.add_argument(
393
+ "--min_tokens_to_keep",
394
+ type=int,
395
+ required=False,
396
+ default=1,
397
+ help="Minimum number of tokens we keep per batch example in the output.",
398
+ )
399
+
400
+ beam_parameters_group.add_argument(
401
+ "--presence_penalty",
402
+ type=float,
403
+ required=False,
404
+ default=0.0,
405
+ help="presence penalty for custom sampling.",
406
+ )
407
+
408
+ beam_parameters_group.add_argument(
409
+ "--custom",
410
+ type=int,
411
+ required=False,
412
+ default=0,
413
+ help="If 1 customized top P logic is applied",
414
+ )
415
+
416
+ beam_parameters_group.add_argument(
417
+ "--vocab_size",
418
+ type=int,
419
+ required=False,
420
+ default=-1,
421
+ help="Vocab_size of the underlying model used to decide the shape of vocab mask",
422
+ )
423
+
424
+ beam_parameters_group.add_argument(
425
+ "--eos_token_id",
426
+ type=int,
427
+ required=False,
428
+ default=-1,
429
+ help="custom eos_token_id for generating model with existing onnx encoder/decoder",
430
+ )
431
+
432
+ beam_parameters_group.add_argument(
433
+ "--pad_token_id",
434
+ type=int,
435
+ required=False,
436
+ default=-1,
437
+ help="custom pad_token_id for generating model with existing onnx encoder/decoder",
438
+ )
439
+
440
+ test_group = parser.add_argument_group("Other options for testing parity and performance")
441
+
442
+ test_group.add_argument(
443
+ "--use_sln_strict_mode",
444
+ required=False,
445
+ action="store_true",
446
+ help="Enable strict mode for SLN in CUDA provider. This ensures a better accuracy but will be slower.",
447
+ )
448
+ test_group.set_defaults(use_sln_strict_mode=False)
449
+
450
+ test_group.add_argument(
451
+ "--use_gpu",
452
+ required=False,
453
+ action="store_true",
454
+ help="use GPU for inference. Required for fp16.",
455
+ )
456
+ test_group.set_defaults(use_gpu=False)
457
+
458
+ test_group.add_argument(
459
+ "--disable_parity",
460
+ required=False,
461
+ action="store_true",
462
+ help="do not run parity test",
463
+ )
464
+ test_group.set_defaults(disable_parity=False)
465
+
466
+ test_group.add_argument(
467
+ "--disable_perf_test",
468
+ required=False,
469
+ action="store_true",
470
+ help="do not run perf test",
471
+ )
472
+ test_group.set_defaults(disable_perf_test=False)
473
+
474
+ test_group.add_argument(
475
+ "--torch_performance",
476
+ required=False,
477
+ action="store_true",
478
+ help="test PyTorch performance",
479
+ )
480
+ test_group.set_defaults(torch_performance=False)
481
+
482
+ test_group.add_argument(
483
+ "--total_runs",
484
+ required=False,
485
+ type=int,
486
+ default=1,
487
+ help="Number of times of inference for latency measurement",
488
+ )
489
+
490
+ test_group.add_argument(
491
+ "--save_test_data",
492
+ required=False,
493
+ action="store_true",
494
+ help="save test data for onnxruntime_perf_test tool",
495
+ )
496
+ test_group.set_defaults(save_test_data=False)
497
+
498
+ args = parser.parse_args(argv)
499
+
500
+ return args
501
+
502
+
503
+ def gpt2_to_onnx(args: argparse.Namespace):
504
+ """Convert GPT-2 model to onnx
505
+
506
+ Args:
507
+ args (argparse.Namespace): arguments parsed from command line
508
+ """
509
+ model_name = args.model_name_or_path
510
+
511
+ arguments = [
512
+ "--model_name_or_path",
513
+ model_name,
514
+ "--output",
515
+ args.decoder_onnx,
516
+ "--optimize_onnx",
517
+ "--precision",
518
+ args.precision,
519
+ "--test_runs",
520
+ "1",
521
+ "--test_cases",
522
+ "10",
523
+ "--overwrite", # Overwrite onnx file if existed
524
+ ]
525
+ if args.cache_dir:
526
+ arguments.extend(["--cache_dir", args.cache_dir])
527
+ if args.use_gpu:
528
+ arguments.append("--use_gpu")
529
+ if args.use_external_data_format:
530
+ arguments.append("--use_external_data_format")
531
+
532
+ if len(args.op_block_list):
533
+ arguments.extend(["--op_block_list"])
534
+ arguments.extend(args.op_block_list)
535
+
536
+ if args.precision == Precision.FLOAT16.value:
537
+ assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
538
+ # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
539
+ # Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
540
+ # Currently logits and past state shall be same data type.
541
+
542
+ if args.verbose:
543
+ logger.info(f"arguments for convert_to_onnx:{arguments}")
544
+
545
+ convert_gpt2_to_onnx(argv=arguments)
546
+
547
+
548
+ def t5_to_onnx(args: argparse.Namespace):
549
+ """Convert T5 model to onnx
550
+
551
+ Args:
552
+ args (argparse.Namespace): arguments parsed from command line
553
+ """
554
+ paths = export_t5_onnx_models(
555
+ model_name_or_path=args.model_name_or_path,
556
+ cache_dir=args.cache_dir,
557
+ output_dir=Path(args.output).parent,
558
+ use_gpu=args.use_gpu,
559
+ use_external_data_format=args.use_external_data_format,
560
+ optimize_onnx=(args.precision != Precision.FLOAT16.value),
561
+ precision=args.precision,
562
+ verbose=False,
563
+ use_decoder_start_token=False,
564
+ overwrite=True,
565
+ disable_auto_mixed_precision=False,
566
+ use_int32_inputs=True,
567
+ model_type=args.model_type,
568
+ encoder_decoder_init=args.encoder_decoder_init,
569
+ force_fp16_io=(args.precision == Precision.FLOAT16.value), # required by BeamSearch op implementation.
570
+ )
571
+
572
+ logger.debug(f"onnx model for encoder: {paths[0]}")
573
+ logger.debug(f"onnx model for decoder: {paths[1]}")
574
+ args.encoder_decoder_init_onnx = paths[0]
575
+ args.decoder_onnx = paths[1]
576
+
577
+
578
+ def shape_inference(onnx_path: str, use_external_data_format: bool = True):
579
+ """Shape inference on an onnx file, which will be overwritten.
580
+
581
+ Args:
582
+ onnx_path (str): Path of onnx model
583
+ use_external_data_format(bool): output tensors to external data or not.
584
+ """
585
+ # Run symbolic shape inference to walk around ORT shape inference issue for subgraph.
586
+ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference # noqa: PLC0415
587
+
588
+ model = onnx.load_model(onnx_path, load_external_data=True)
589
+ out = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=False)
590
+ if out:
591
+ OnnxModel.save(out, onnx_path, save_as_external_data=use_external_data_format)
592
+ else:
593
+ logger.warning("Failed to run symbolic shape inference on the model.")
594
+
595
+
596
+ def pad_weights_of_logits_matmul(onnx_path: str, use_external_data_format: bool = True) -> bool:
597
+ """Pad the logits MatMul weight in the provided decoder model, which will be overwritten.
598
+
599
+ Args:
600
+ onnx_path (str): Path of onnx model
601
+ use_external_data_format(bool): output tensors to external data or not.
602
+ """
603
+ decoder_model_proto = onnx.load_model(onnx_path, load_external_data=True)
604
+
605
+ logits_output_name = decoder_model_proto.graph.output[0].name
606
+
607
+ decoder_model = OnnxModel(decoder_model_proto)
608
+
609
+ output_name_to_node = decoder_model.output_name_to_node()
610
+ assert logits_output_name in output_name_to_node
611
+
612
+ matmul_node = output_name_to_node[logits_output_name]
613
+ # Sanity check - the logits need to be produced by a MatMul node
614
+ if matmul_node.op_type != "MatMul":
615
+ return False
616
+
617
+ # The logits MatMul weight MUST be an initializer (or)
618
+ # it MUST be flowing through a Transpose whose input is
619
+ # an initializer
620
+ pad_along_axis_1 = True
621
+ logits_weight = decoder_model.get_initializer(matmul_node.input[1])
622
+ if logits_weight is None:
623
+ transpose_before_matmul = decoder_model.match_parent(matmul_node, "Transpose", 1)
624
+
625
+ if transpose_before_matmul is None:
626
+ return False
627
+
628
+ logits_weight = decoder_model.get_initializer(transpose_before_matmul.input[0])
629
+
630
+ if logits_weight is None:
631
+ return False
632
+
633
+ pad_along_axis_1 = False
634
+
635
+ # The logits MatMul weight MUST be fp16
636
+ if logits_weight.data_type != TensorProto.DataType.FLOAT16:
637
+ return False
638
+
639
+ # The logits MatMul weight MUST be 2-dimensional
640
+ if len(logits_weight.dims) != 2:
641
+ return False
642
+
643
+ # Pad and over-write the initializer (if needed)
644
+ actual_vocab_size = logits_weight.dims[1]
645
+
646
+ if (actual_vocab_size % 8) == 0:
647
+ # Already "padded"
648
+ return True
649
+
650
+ padded_vocab_size = math.ceil(actual_vocab_size / 8) * 8
651
+ padding = padded_vocab_size - actual_vocab_size
652
+
653
+ # TODO(hasesh): Handle cases where the fp16 data is stored in the
654
+ # non-raw data field
655
+ if logits_weight.raw_data:
656
+ if pad_along_axis_1:
657
+ padding_data = np.zeros((logits_weight.dims[0], padding), dtype=np.float16)
658
+ weight_with_padding = np.concatenate((NumpyHelper.to_array(logits_weight), padding_data), axis=1)
659
+ logits_weight.dims[1] = padded_vocab_size
660
+ else:
661
+ padding_data = np.zeros((padding, logits_weight.dims[1]), dtype=np.float16)
662
+ weight_with_padding = np.concatenate((NumpyHelper.to_array(logits_weight), padding_data), axis=0)
663
+ logits_weight.dims[0] = padded_vocab_size
664
+
665
+ logits_weight.raw_data = weight_with_padding.tobytes()
666
+ else:
667
+ return False
668
+
669
+ # Save the model
670
+ OnnxModel.save(decoder_model_proto, onnx_path, save_as_external_data=use_external_data_format)
671
+ return True
672
+
673
+
674
+ def create_ort_session(model_path: str, use_gpu: bool, use_sln_strict_mode: bool) -> InferenceSession:
675
+ """Create OnnxRuntime session.
676
+
677
+ Args:
678
+ model_path (str): onnx model path
679
+ use_gpu (bool): use GPU or not
680
+ use_sln_strict_mode (bool): use strict mode for skip layer normalization or not
681
+
682
+ Raises:
683
+ RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified.
684
+
685
+ Returns:
686
+ onnxruntime.InferenceSession: The created session.
687
+ """
688
+ sess_options = SessionOptions()
689
+ sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
690
+ execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
691
+ if use_gpu:
692
+ if "CUDAExecutionProvider" not in get_available_providers():
693
+ raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!")
694
+ else:
695
+ logger.info("use CUDAExecutionProvider")
696
+ if use_sln_strict_mode:
697
+ cuda_provider_options = {"enable_skip_layer_norm_strict_mode": True}
698
+ provider_options = {"CUDAExecutionProvider": cuda_provider_options}
699
+ execution_providers = [
700
+ (name, provider_options[name]) if name in provider_options else name for name in execution_providers
701
+ ]
702
+
703
+ ort_session = InferenceSession(model_path, sess_options, providers=execution_providers)
704
+ return ort_session
705
+
706
+
707
+ def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision):
708
+ """Verify GPT-2 subgraph
709
+
710
+ Args:
711
+ graph (onnx.GraphProto): onnx graph of GPT-2
712
+ precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
713
+
714
+ Raises:
715
+ ValueError: Number of inputs not expected.
716
+ ValueError: Input name is not expected.
717
+ ValueError: Input data type is not expected.
718
+ ValueError: Number of outputs not expected.
719
+ ValueError: Output name is not expected.
720
+ ValueError: Output data type is not expected.
721
+ """
722
+ is_float16 = precision == Precision.FLOAT16.value
723
+
724
+ input_count = len(graph.input)
725
+ layer_count = input_count - 3
726
+ assert layer_count >= 1
727
+
728
+ expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
729
+ if len(graph.input) != len(expected_inputs):
730
+ raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
731
+
732
+ for i, expected_input in enumerate(expected_inputs):
733
+ if graph.input[i].name != expected_input:
734
+ raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
735
+
736
+ expected_type = TensorProto.INT32
737
+ if i >= 3:
738
+ expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
739
+
740
+ input_type = graph.input[i].type.tensor_type.elem_type
741
+ if input_type != expected_type:
742
+ raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
743
+ logger.info("Verifying GPT-2 graph inputs: name and data type are good.")
744
+
745
+ expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)]
746
+ if len(graph.output) != len(expected_outputs):
747
+ raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
748
+
749
+ for i, expected_output in enumerate(expected_outputs):
750
+ if graph.output[i].name != expected_output:
751
+ raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
752
+
753
+ expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
754
+ output_type = graph.output[i].type.tensor_type.elem_type
755
+ if output_type != expected_type:
756
+ raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {output_type}")
757
+ logger.info("Verifying GPT-2 graph outputs: name and data type are good.")
758
+
759
+ # TODO(tianleiwu): verify shapes of inputs and outputs.
760
+ return
761
+
762
+
763
+ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
764
+ """Verify T5 decoder subgraph
765
+
766
+ Args:
767
+ graph (onnx.GraphProto): onnx graph of T5 decoder
768
+ precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
769
+
770
+ Raises:
771
+ ValueError: Number of inputs not expected.
772
+ ValueError: Input name is not expected.
773
+ ValueError: Input data type is not expected.
774
+ ValueError: Number of outputs not expected.
775
+ ValueError: Output name is not expected.
776
+ ValueError: Output data type is not expected.
777
+ """
778
+ is_float16 = precision == Precision.FLOAT16.value
779
+ float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
780
+
781
+ input_count = len(graph.input)
782
+ layer_count = (input_count - 2) // 4
783
+ assert layer_count >= 1
784
+
785
+ # Expect inputs:
786
+ # input_ids: int32 (B, 1)
787
+ # encoder_attention_mask: int32 (B, encode_sequence_length)
788
+
789
+ # past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
790
+ # past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
791
+ # ... (for each self attention layer)
792
+
793
+ # past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
794
+ # past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
795
+ # ... (for each cross attention layer)
796
+
797
+ # TODO: encoder_hidden_states is optional
798
+ expected_inputs = ["input_ids", "encoder_attention_mask"]
799
+ for i in range(layer_count):
800
+ expected_inputs.append(f"past_key_self_{i}")
801
+ expected_inputs.append(f"past_value_self_{i}")
802
+ for i in range(layer_count):
803
+ expected_inputs.append(f"past_key_cross_{i}")
804
+ expected_inputs.append(f"past_value_cross_{i}")
805
+
806
+ if len(graph.input) != len(expected_inputs):
807
+ raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
808
+
809
+ for i, expected_input in enumerate(expected_inputs):
810
+ if graph.input[i].name != expected_input:
811
+ raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
812
+
813
+ expected_type = TensorProto.INT32 if i < 2 else float_type
814
+ input_type = graph.input[i].type.tensor_type.elem_type
815
+ if input_type != expected_type:
816
+ raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
817
+
818
+ # Expect outputs:
819
+ # logits: (B, 1, vocab_size)
820
+ # present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
821
+ # present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
822
+ # ... (for each self attention layer)
823
+ expected_outputs = ["logits"]
824
+ for i in range(layer_count):
825
+ expected_outputs.append(f"present_key_self_{i}")
826
+ expected_outputs.append(f"present_value_self_{i}")
827
+
828
+ if len(graph.output) != len(expected_outputs):
829
+ raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
830
+
831
+ for i, expected_output in enumerate(expected_outputs):
832
+ if graph.output[i].name != expected_output:
833
+ raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
834
+ output_type = graph.output[i].type.tensor_type.elem_type
835
+ if output_type != float_type:
836
+ raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}")
837
+
838
+
839
+ def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision):
840
+ """Verify T5 decoder subgraph
841
+
842
+ Args:
843
+ graph (onnx.GraphProto): onnx graph of T5 decoder
844
+ precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
845
+
846
+ Raises:
847
+ ValueError: Number of inputs not expected.
848
+ ValueError: Input name is not expected.
849
+ ValueError: Input data type is not expected.
850
+ ValueError: Number of outputs not expected.
851
+ ValueError: Output name is not expected.
852
+ ValueError: Output data type is not expected.
853
+ """
854
+ is_float16 = precision == Precision.FLOAT16.value
855
+ new_format = "cross" in graph.output[0].name
856
+
857
+ # Expect 3 inputs:
858
+ # encoder_input_ids: int32 (B, encode_sequence_length)
859
+ # encoder_attention_mask: int32 (B, encode_sequence_length)
860
+ # decoder_input_ids: int32 (B, 1)
861
+ expected_inputs = [
862
+ "encoder_input_ids",
863
+ "encoder_attention_mask",
864
+ "decoder_input_ids",
865
+ ]
866
+ if new_format:
867
+ expected_inputs = expected_inputs[:2]
868
+ if len(graph.input) != len(expected_inputs):
869
+ raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
870
+
871
+ for i, expected_input in enumerate(expected_inputs):
872
+ if graph.input[i].name != expected_input:
873
+ raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
874
+
875
+ expected_type = TensorProto.INT32
876
+ input_type = graph.input[i].type.tensor_type.elem_type
877
+ if input_type != expected_type:
878
+ raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
879
+
880
+ if new_format:
881
+ assert len(graph.output) % 2 == 0
882
+ layer_count = len(graph.output) // 2
883
+ assert layer_count >= 1
884
+
885
+ # Expected outputs:
886
+ # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
887
+ # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
888
+ # ... (for each cross attention layer)
889
+ expected_outputs = []
890
+ for i in range(layer_count):
891
+ expected_outputs.append(f"present_key_cross_{i}")
892
+ expected_outputs.append(f"present_value_cross_{i}")
893
+ else:
894
+ logger.warning("This format is deprecated. Please export T5 encoder in new format with only cross outputs.")
895
+ assert (len(graph.output) - 2) % 4 == 0
896
+ layer_count = (len(graph.output) - 2) // 4
897
+ assert layer_count >= 1
898
+
899
+ # Expected outputs:
900
+ # logits: (B, 1, vocab_size)
901
+ # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
902
+ # present_key_self_0: (B, num_heads, 1, head_size)
903
+ # present_value_self_0: (B, num_heads, 1, head_size)
904
+ # ... (for each self attention layer)
905
+ # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
906
+ # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
907
+ # ... (for each cross attention layer)
908
+ expected_outputs = ["logits", "encoder_hidden_states"]
909
+ for i in range(layer_count):
910
+ expected_outputs.append(f"present_key_self_{i}")
911
+ expected_outputs.append(f"present_value_self_{i}")
912
+ for i in range(layer_count):
913
+ expected_outputs.append(f"present_key_cross_{i}")
914
+ expected_outputs.append(f"present_value_cross_{i}")
915
+
916
+ if len(graph.output) != len(expected_outputs):
917
+ raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
918
+
919
+ for i, expected_output in enumerate(expected_outputs):
920
+ if graph.output[i].name != expected_output:
921
+ raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
922
+
923
+ expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
924
+ output_type = graph.output[i].type.tensor_type.elem_type
925
+ if output_type != expected_type:
926
+ raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}")
927
+
928
+ logger.info("T5 encoder graph verified: name and data type of inputs and outputs are good.")
929
+
930
+
931
+ def remove_shared_initializers(
932
+ graph1: GraphProto,
933
+ graph2: GraphProto,
934
+ shared_prefix: str = "shared_",
935
+ min_elements: int = 1024,
936
+ signature_cache1: dict | None = None,
937
+ signature_cache2: dict | None = None,
938
+ ):
939
+ """Remove initializers with same value from two graphs.
940
+
941
+ Args:
942
+ graph1 (GraphProto): the first graph to process
943
+ graph2 (GraphProto): the second graph to process
944
+ shared_prefix (str): prefix to add to the initializers shared between the two graphs
945
+ min_elements (int, optional): minimum number of elements for an initializer to be considered. Defaults to 1024.
946
+ signature_cache1 (dict): Optional dictionary to store data signatures of tensors in graph1 in order to speed up comparison
947
+ signature_cache2 (dict): Optional dictionary to store data signatures of tensors in graph2 in order to speed up comparison
948
+ """
949
+
950
+ mapping_initializers_1 = {}
951
+ mapping_initializers_2 = {}
952
+ shared_initializers_1 = []
953
+ shared_initializers_2 = []
954
+ shared_initializers_names = []
955
+
956
+ for initializer1 in graph1.initializer:
957
+ if not (initializer1.dims and sum(initializer1.dims) >= min_elements):
958
+ continue
959
+
960
+ for initializer2 in graph2.initializer:
961
+ if not (initializer2.dims and sum(initializer2.dims) >= min_elements):
962
+ continue
963
+
964
+ if OnnxModel.has_same_value(initializer1, initializer2, signature_cache1, signature_cache2):
965
+ mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name
966
+ shared_initializers_1.append(initializer1)
967
+
968
+ if initializer2.name not in mapping_initializers_2:
969
+ shared_name = shared_prefix + initializer2.name
970
+ mapping_initializers_2[initializer2.name] = shared_name
971
+ shared_initializers_2.append(initializer2)
972
+ shared_initializers_names.append(shared_name)
973
+ break
974
+
975
+ logger.debug(f"shared initializers:{shared_initializers_names}")
976
+
977
+ # Make sure new name does not exist in graph 1
978
+ for node in graph1.node:
979
+ for j in range(len(node.input)):
980
+ if node.input[j] in shared_initializers_names:
981
+ raise RuntimeError(f"name is found in graph 1: {node.input[j]}")
982
+
983
+ # Make sure new name does not exist in graph 2
984
+ for node in graph2.node:
985
+ for j in range(len(node.input)):
986
+ if node.input[j] in shared_initializers_names:
987
+ raise RuntimeError(f"name is found in graph 2: {node.input[j]}")
988
+
989
+ # Remove shared initializers from graph 2
990
+ for initializer in shared_initializers_2:
991
+ graph2.initializer.remove(initializer)
992
+
993
+ # Rename value info for old names in graph 2
994
+ for value_info in graph2.value_info:
995
+ if value_info.name in mapping_initializers_2:
996
+ value_info.name = mapping_initializers_2[value_info.name]
997
+
998
+ # Rename nodes inputs in graph 2:
999
+ for node in graph2.node:
1000
+ for j in range(len(node.input)):
1001
+ if node.input[j] in mapping_initializers_2:
1002
+ new_name = mapping_initializers_2[node.input[j]]
1003
+ logger.debug(f"graph 2 rename node {node.name} input {j} from {node.input[j]} to {new_name}")
1004
+ node.input[j] = new_name
1005
+
1006
+ # Remove shared initializers from graph 1
1007
+ for initializer in shared_initializers_1:
1008
+ graph1.initializer.remove(initializer)
1009
+
1010
+ # Rename value info for old names in graph 1
1011
+ for value_info in graph1.value_info:
1012
+ if value_info.name in mapping_initializers_1:
1013
+ value_info.name = mapping_initializers_1[value_info.name]
1014
+
1015
+ # Rename nodes inputs in graph 1:
1016
+ for node in graph1.node:
1017
+ for j in range(len(node.input)):
1018
+ if node.input[j] in mapping_initializers_1:
1019
+ new_name = mapping_initializers_1[node.input[j]]
1020
+ logger.debug(f"graph 1 rename node {node.name} input {j} from {node.input[j]} to {new_name}")
1021
+ node.input[j] = new_name
1022
+
1023
+ # Rename shared initializers in graph 2
1024
+ for initializer in shared_initializers_2:
1025
+ initializer.name = mapping_initializers_2[initializer.name]
1026
+
1027
+ for initializer in shared_initializers_2:
1028
+ shape = onnx.numpy_helper.to_array(initializer).shape
1029
+ value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
1030
+ # Need to add value_info for initializers moved to the parent graph. Otherwise, ORT will fail.
1031
+ graph1.value_info.append(value_info)
1032
+ graph2.value_info.append(value_info)
1033
+
1034
+ return shared_initializers_2
1035
+
1036
+
1037
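The following is a minimal editorial sketch, not part of the packaged module, showing how remove_shared_initializers might be exercised on two toy graphs; the _tiny_graph helper and all tensor and node names are invented for illustration, and the function is assumed to be available in this module's namespace.

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

def _tiny_graph(name: str) -> onnx.GraphProto:
    # One Add node consuming a 2048-element weight (sum of dims >= 1024, so it is eligible for sharing).
    weight = numpy_helper.from_array(np.ones((2048,), dtype=np.float32), name="w")
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2048])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2048])
    add = helper.make_node("Add", ["x", "w"], ["y"], name=f"{name}_add")
    return helper.make_graph([add], name, [x], [y], initializer=[weight])

graph1, graph2 = _tiny_graph("g1"), _tiny_graph("g2")
shared = remove_shared_initializers(graph1, graph2, shared_prefix="shared_")
print([t.name for t in shared])                          # ['shared_w']
print(len(graph1.initializer), len(graph2.initializer))  # 0 0
print(graph1.node[0].input[1], graph2.node[0].input[1])  # shared_w shared_w

The returned tensors are typically promoted to a parent graph (see get_shared_initializers below), which is why value_info entries are appended to both subgraphs.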
+ def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto):
1038
+ encoder = OnnxModel(encoder_model)
1039
+ decoder = OnnxModel(decoder_model)
1040
+ encoder.add_prefix_to_names("e_")
1041
+ decoder.add_prefix_to_names("d_")
1042
+ signature_cache1, signature_cache2 = {}, {}
1043
+ encoder.remove_duplicated_initializer(signature_cache1)
1044
+ decoder.remove_duplicated_initializer(signature_cache2)
1045
+ initializers = remove_shared_initializers(
1046
+ decoder.model.graph,
1047
+ encoder.model.graph,
1048
+ shared_prefix="s_",
1049
+ signature_cache1=signature_cache1,
1050
+ signature_cache2=signature_cache2,
1051
+ )
1052
+ return initializers
1053
+
1054
+
1055
+ def move_initializers(
1056
+ graph: GraphProto,
1057
+ min_elements: int = 1024,
1058
+ ) -> list[TensorProto]:
1059
+ """Remove initializers of a graph, when they have number of elements larger than a threshold.
1060
+
1061
+ Args:
1062
+ graph (GraphProto): the graph.
1063
+ min_elements (int, optional): minimum number of elements for an initializer to be considered. Defaults to 1024.
1064
+
1065
+ Returns:
1066
+ List[TensorProto]: initializers that are removed from the graph.
1067
+ """
1068
+ moved_initializers = []
1069
+ for tensor in graph.initializer:
1070
+ if not (tensor.dims and sum(tensor.dims) >= min_elements):
1071
+ continue
1072
+ moved_initializers.append(tensor)
1073
+
1074
+ for initializer in moved_initializers:
1075
+ graph.initializer.remove(initializer)
1076
+
1077
+ # Add type info, otherwise ORT will raise error: "input arg (*) does not have type information set by parent node."
1078
+ for initializer in moved_initializers:
1079
+ shape = onnx.numpy_helper.to_array(initializer).shape
1080
+ value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
1081
+ graph.value_info.append(value_info)
1082
+
1083
+ return moved_initializers
1084
+
1085
+
1086
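A small sketch, not part of the module, of the effect of move_initializers on an invented graph: initializers whose dims sum to at least min_elements are detached, and a matching value_info is left behind so ORT still sees their type.

import numpy as np
from onnx import TensorProto, helper, numpy_helper

big = numpy_helper.from_array(np.zeros((4096,), dtype=np.float32), name="big_weight")
small = numpy_helper.from_array(np.zeros((4,), dtype=np.float32), name="small_weight")
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4096])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [4096])
add = helper.make_node("Add", ["x", "big_weight"], ["y"], name="add_0")
graph = helper.make_graph([add], "demo", [x], [y], initializer=[big, small])

moved = move_initializers(graph, min_elements=1024)
print([t.name for t in moved])               # ['big_weight']
print([t.name for t in graph.initializer])   # ['small_weight']
print([vi.name for vi in graph.value_info])  # ['big_weight']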
+ def _attribute_to_pair(attribute):
1087
+ """
1088
+ Convert attribute to kwarg format for use with onnx.helper.make_node.
1089
+ :parameter attribute: attribute in AttributeProto format.
1090
+ :return: attribute in {key: value} format.
1091
+ """
1092
+ if attribute.type == 0:
1093
+ raise ValueError(f"attribute {attribute.name} does not have type specified.")
1094
+
1095
+ # Based on attribute type definitions from AttributeProto
1096
+ # definition in https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
1097
+ if attribute.type == 1:
1098
+ value = attribute.f
1099
+ elif attribute.type == 2:
1100
+ value = attribute.i
1101
+ elif attribute.type == 3:
1102
+ value = attribute.s
1103
+ elif attribute.type == 4:
1104
+ value = attribute.t
1105
+ elif attribute.type == 5:
1106
+ value = attribute.g
1107
+ elif attribute.type == 6:
1108
+ value = attribute.floats
1109
+ elif attribute.type == 7:
1110
+ value = attribute.ints
1111
+ elif attribute.type == 8:
1112
+ value = attribute.strings
1113
+ elif attribute.type == 9:
1114
+ value = attribute.tensors
1115
+ elif attribute.type == 10:
1116
+ value = attribute.graphs
1117
+ else:
1118
+ raise ValueError(f"attribute {attribute.name} has unsupported type {attribute.type}.")
1119
+
1120
+ return (attribute.name, value)
1121
+
1122
+
1123
+ def kwargs_of(node):
1124
+ kwargs = {}
1125
+ for attr in node.attribute:
1126
+ (key, value) = _attribute_to_pair(attr)
1127
+ kwargs.update({key: value})
1128
+ if node.domain:
1129
+ kwargs.update({"domain": node.domain})
1130
+ return kwargs
1131
+
1132
+
1133
+ def shape_of(vi):
1134
+ return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])
1135
+
1136
+
1137
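A brief sketch (illustrative names only, not part of the module) of the two small helpers above: shape_of returns symbolic or numeric dims as a tuple, and kwargs_of turns a node's attributes plus its domain back into keyword arguments accepted by onnx.helper.make_node.

from onnx import TensorProto, helper

vi = helper.make_tensor_value_info("past_key", TensorProto.FLOAT, ["batch", 8, "seq", 64])
print(shape_of(vi))     # ('batch', 8, 'seq', 64)

attn = helper.make_node("Attention", ["q"], ["out"], name="attn_0", domain="com.microsoft", num_heads=8)
print(kwargs_of(attn))  # {'num_heads': 8, 'domain': 'com.microsoft'}

# Round-trip: rebuild an equivalent node from the extracted kwargs, as the rewrite passes below do.
rebuilt = helper.make_node(attn.op_type, attn.input, attn.output, name=attn.name, **kwargs_of(attn))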
+ def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto):
1138
+ input_past_0 = 3
1139
+ output_past_0 = 1
1140
+ new_inputs = []
1141
+ for i, vi in enumerate(subg.input):
1142
+ if i >= input_past_0:
1143
+ shape = shape_of(vi)
1144
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
1145
+ vi.name,
1146
+ elem_type=vi.type.tensor_type.elem_type,
1147
+ shape=[shape[0], shape[1], shape[2], "max_seq_len", shape[4]],
1148
+ )
1149
+ new_inputs.extend([vi])
1150
+ new_inputs.extend([onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])])
1151
+ subg.ClearField("input")
1152
+ subg.input.extend(new_inputs)
1153
+
1154
+ new_outputs = []
1155
+ for i, vi in enumerate(subg.output):
1156
+ if i >= output_past_0:
1157
+ shape = shape_of(vi)
1158
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
1159
+ vi.name,
1160
+ elem_type=vi.type.tensor_type.elem_type,
1161
+ shape=[shape[0], shape[1], shape[2], "max_seq_len", shape[4]],
1162
+ )
1163
+ new_outputs.extend([vi])
1164
+ subg.ClearField("output")
1165
+ subg.output.extend(new_outputs)
1166
+
1167
+ new_nodes = []
1168
+ for node in subg.node:
1169
+ new_node = node
1170
+ if node.op_type == "Attention":
1171
+ kwargs = kwargs_of(node)
1172
+ kwargs.update({"past_present_share_buffer": 1})
1173
+ nis = []
1174
+ nis.extend(node.input)
1175
+ while len(nis) < 6:
1176
+ nis.extend([""])
1177
+ if len(nis) < 7:
1178
+ nis.extend(["past_sequence_length"])
1179
+ new_node = onnx.helper.make_node("Attention", nis, node.output, name=node.name, **kwargs)
1180
+ new_nodes.extend([new_node])
1181
+ subg.ClearField("node")
1182
+ subg.node.extend(new_nodes)
1183
+ return subg
1184
+
1185
+
1186
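A hedged sketch of update_decoder_subgraph_past_present_share_buffer on an invented, GPT-2-like toy subgraph (not a real decoder): past/present tensors get a symbolic max_seq_len dimension, a past_sequence_length input is appended, and each Attention node gains past_present_share_buffer=1 plus a past_sequence_length input.

from onnx import TensorProto, helper

inputs = [
    helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch", "seq"]),
    helper.make_tensor_value_info("position_ids", TensorProto.INT32, ["batch", "seq"]),
    helper.make_tensor_value_info("attention_mask", TensorProto.INT32, ["batch", "seq"]),
    helper.make_tensor_value_info("past_0", TensorProto.FLOAT, [2, "batch", 8, "past_seq", 64]),
]
outputs = [
    helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch", "seq", 1000]),
    helper.make_tensor_value_info("present_0", TensorProto.FLOAT, [2, "batch", 8, "total_seq", 64]),
]
attn = helper.make_node(
    "Attention", ["hidden", "qkv_weight", "qkv_bias", "attention_mask", "past_0"],
    ["attn_out", "present_0"], name="attn_0", domain="com.microsoft", num_heads=8,
)
subg = helper.make_graph([attn], "decoder", inputs, outputs)

update_decoder_subgraph_past_present_share_buffer(subg)
print(shape_of(subg.input[3]))        # (2, 'batch', 8, 'max_seq_len', 64)
print(subg.input[-1].name)            # past_sequence_length
print(list(subg.node[0].input)[-2:])  # ['', 'past_sequence_length']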
+ def update_decoder_subgraph_use_decoder_masked_attention(
1187
+ subg: GraphProto, is_beam_search: bool, switch_attention: bool
1188
+ ) -> bool:
1189
+ """Update the Attention nodes to DecoderMaskedSelfAttention.
1190
+
1191
+ Args:
1192
+ subg (GraphProto): GraphProto of the decoder subgraph
1193
+ is_beam_search (bool): Boolean specifying if the sampling algo is BeamSearch
1194
+ switch_attention (bool): Boolean specifying if `Attention` is to be switched with `DecoderMaskedSelfAttention`
1195
+ """
1196
+ if is_beam_search:
1197
+ new_inputs = []
1198
+ for _i, vi in enumerate(subg.input):
1199
+ new_inputs.extend([vi])
1200
+
1201
+ # Add 2 BeamSearch specific inputs
1202
+ new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
1203
+ new_inputs.extend(
1204
+ [
1205
+ onnx.helper.make_tensor_value_info(
1206
+ "cache_indirection",
1207
+ onnx.TensorProto.INT32,
1208
+ shape=["batch_size", "beam_width", "max_seq_len"],
1209
+ )
1210
+ ]
1211
+ )
1212
+ subg.ClearField("input")
1213
+ subg.input.extend(new_inputs)
1214
+
1215
+ if switch_attention:
1216
+ decoder_masked_attention_supported_attr = [
1217
+ "past_present_share_buffer",
1218
+ "num_heads",
1219
+ "scale",
1220
+ "mask_filter_value",
1221
+ "domain",
1222
+ ]
1223
+
1224
+ new_nodes = []
1225
+ for node in subg.node:
1226
+ if node.op_type == "Attention":
1227
+ kwargs = kwargs_of(node)
1228
+ for k in kwargs.copy():
1229
+ # The Attention operator does not support different qkv hidden sizes when past/present
1230
+ # input/output exists (GPT2 model). Hence, we should never run into this.
1231
+ # But, if we do, do not go ahead with the optimization.
1232
+ if k == "qkv_hidden_sizes":
1233
+ return False
1234
+
1235
+ if k not in decoder_masked_attention_supported_attr:
1236
+ # Log the fact that we are removing certain attributes from the node
1237
+ # We don't need to log it for "unidirectional" as we are aware that
1238
+ # decoding attention kernels are unidirectional by definition.
1239
+ if k != "unidirectional":
1240
+ logger.warning(
1241
+ f"Removing attribute: {k} from Attention node while switching to DecoderMaskedSelfAttention"
1242
+ )
1243
+
1244
+ del kwargs[k]
1245
+
1246
+ nis = []
1247
+ nis.extend(node.input)
1248
+
1249
+ # Add 2 BeamSearch specific inputs
1250
+ if is_beam_search:
1251
+ while len(nis) < 7:
1252
+ nis.extend([""])
1253
+ if len(nis) < 8:
1254
+ nis.extend(["beam_width"])
1255
+ if len(nis) < 9:
1256
+ nis.extend(["cache_indirection"])
1257
+
1258
+ node = onnx.helper.make_node( # noqa: PLW2901
1259
+ "DecoderMaskedSelfAttention",
1260
+ nis,
1261
+ node.output,
1262
+ name=node.name,
1263
+ **kwargs,
1264
+ )
1265
+ new_nodes.extend([node])
1266
+ subg.ClearField("node")
1267
+ subg.node.extend(new_nodes)
1268
+
1269
+ return True
1270
+
1271
+
1272
+ def find_past_seq_len_usage(subg: GraphProto):
1273
+ """Correct graph which originally use dim of past_seq_len from input_ids's shape which is fixed to max_seq_len after
1274
+ fixed to max_seq_len once the past/present buffer is shared
1275
+
1276
+ Args:
1277
+ subg (GraphProto): GraphProto of the decoder subgraph
1278
+ return:
1279
+ tensor_names_to_rename : set of tensor names whose values equal past_sequence_length
1280
+ nodes_to_remove : list of nodes to remove
1281
+ """
1282
+ tensor_names_to_rename = set()
1283
+ nodes_to_remove = []
1284
+
1285
+ graph_input_names = {inp.name: index for index, inp in enumerate(subg.input)}
1286
+
1287
+ input_name_to_nodes = {}
1288
+ output_name_to_node = {}
1289
+ for node in subg.node:
1290
+ for input_name in node.input:
1291
+ if input_name:
1292
+ if input_name not in input_name_to_nodes:
1293
+ input_name_to_nodes[input_name] = [node]
1294
+ else:
1295
+ input_name_to_nodes[input_name].append(node)
1296
+ for output_name in node.output:
1297
+ if output_name:
1298
+ output_name_to_node[output_name] = node
1299
+
1300
+ for node in subg.node:
1301
+ # find "past_key_self_0 --> [Transpose(past_key_self_0) --> Reshape(past_key_self_0)] --> Shape(past_key_self_0) --> Gather(*, 2)"
1302
+ # where [Transpose(past_key_self_0) --> Reshape(past_key_self_0)] may or may not exist
1303
+ if node.op_type == "Gather":
1304
+ if not node.input[1] or not node.input[0]:
1305
+ continue
1306
+
1307
+ # Find Gather node's index value
1308
+ shape_tensor_name, shape_index_name = (node.input[0], node.input[1])
1309
+ ini_gather_indices = None
1310
+ if "Constant_" in shape_index_name:
1311
+ # If shape_index_name refers to a Constant node
1312
+ for const_node in subg.node:
1313
+ if const_node.op_type == "Constant" and const_node.output[0] == shape_index_name:
1314
+ ini_gather_indices = const_node.attribute[0].t
1315
+ break
1316
+ else:
1317
+ # If shape_index_name refers to an initializer
1318
+ for tensor in subg.initializer:
1319
+ if tensor.name == shape_index_name:
1320
+ ini_gather_indices = tensor
1321
+ break
1322
+ if ini_gather_indices is None:
1323
+ continue
1324
+ gather_indices_arr = onnx.numpy_helper.to_array(ini_gather_indices)
1325
+
1326
+ if (
1327
+ gather_indices_arr.size == 1
1328
+ and gather_indices_arr.item() in {1, 2}
1329
+ and node.input[0] in output_name_to_node
1330
+ ):
1331
+ shape_node = output_name_to_node[shape_tensor_name]
1332
+ if not (shape_node.op_type == "Shape" and shape_node.input[0]):
1333
+ continue
1334
+
1335
+ if (
1336
+ shape_node.input[0] in graph_input_names
1337
+ and (
1338
+ shape_node.input[0].startswith("past_key_self_")
1339
+ or shape_node.input[0].startswith("past_value_self_")
1340
+ )
1341
+ and gather_indices_arr.item() == 2
1342
+ ):
1343
+ # "past_key_self_0 --> Shape(past_key_self_0) --> Gather(*, 2)"
1344
+ tensor_names_to_rename.add(node.output[0])
1345
+ nodes_to_remove.append(node)
1346
+ if len(input_name_to_nodes[shape_node.output[0]]) == 1:
1347
+ nodes_to_remove.append(shape_node)
1348
+ continue
1349
+
1350
+ if shape_node.input[0] not in output_name_to_node:
1351
+ continue
1352
+ reshape_node = output_name_to_node[shape_node.input[0]]
1353
+ if not (reshape_node.op_type == "Reshape" and reshape_node.input[0]):
1354
+ continue
1355
+ transpose_node = output_name_to_node[reshape_node.input[0]]
1356
+ if not (transpose_node.op_type == "Transpose" and transpose_node.input[0]):
1357
+ continue
1358
+
1359
+ if (
1360
+ transpose_node.input[0] in graph_input_names
1361
+ and (
1362
+ transpose_node.input[0].startswith("past_key_self_")
1363
+ or transpose_node.input[0].startswith("past_value_self_")
1364
+ )
1365
+ and gather_indices_arr.item() == 1
1366
+ ):
1367
+ # "past_key_self_0 --> Transpose(past_key_self_0) --> Reshape(past_key_self_0) --> Shape(past_key_self_0) --> Gather(*, 2)"
1368
+ tensor_names_to_rename.add(node.output[0])
1369
+ nodes_to_remove.extend([node, shape_node, reshape_node])
1370
+ if len(input_name_to_nodes[transpose_node.output[0]]) == 1:
1371
+ nodes_to_remove.append(transpose_node)
1372
+ continue
1373
+
1374
+ return tensor_names_to_rename, nodes_to_remove
1375
+
1376
+
1377
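A self-contained sketch (toy graph with invented names, not part of the module) of the pattern find_past_seq_len_usage detects: a Gather(indices=2) over the Shape of a past key input yields the past sequence length, which must be re-sourced once buffer sharing turns that dimension into max_seq_len.

import numpy as np
from onnx import TensorProto, helper, numpy_helper

past = helper.make_tensor_value_info("past_key_self_0", TensorProto.FLOAT, ["batch", 8, "past_seq", 64])
idx = numpy_helper.from_array(np.array([2], dtype=np.int64), name="gather_idx")
shape_node = helper.make_node("Shape", ["past_key_self_0"], ["past_shape"], name="shape_0")
gather_node = helper.make_node("Gather", ["past_shape", "gather_idx"], ["past_seq_len"], name="gather_0", axis=0)
out = helper.make_tensor_value_info("past_seq_len", TensorProto.INT64, [1])
subg = helper.make_graph([shape_node, gather_node], "demo", [past], [out], initializer=[idx])

names, removable = find_past_seq_len_usage(subg)
print(names)                        # {'past_seq_len'}
print([n.name for n in removable])  # ['gather_0', 'shape_0']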
+ def add_cache_indirection_to_mha(model: OnnxModel, past_seq_len_name: str):
1378
+ # Add past_sequence_length and cache_indirection as inputs to all MultiHeadAttention ops and as inputs to model
1379
+ cache_indirection_name = "cache_indirection"
1380
+ mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
1381
+ for node in mha_nodes:
1382
+ # MHA op takes the following potential inputs:
1383
+ # query, key, value, bias, key_padding_mask, add_qk, past_key, past_value
1384
+ while len(node.input) < 8:
1385
+ node.input.append("")
1386
+ node.input.append(past_seq_len_name)
1387
+ node.input.append(cache_indirection_name)
1388
+
1389
+ model.model.graph.input.append(
1390
+ onnx.helper.make_tensor_value_info(
1391
+ cache_indirection_name, TensorProto.INT32, shape=["batch_size", "beam_width", "max_sequence_length"]
1392
+ ),
1393
+ )
1394
+ model.topological_sort()
1395
+ return model
1396
+
1397
+
1398
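An illustrative sketch (one invented MultiHeadAttention node; OnnxModel is the wrapper this module already imports) of how add_cache_indirection_to_mha pads MHA inputs to eight entries and then appends past_sequence_length and cache_indirection.

from onnx import TensorProto, helper

q = helper.make_tensor_value_info("query", TensorProto.FLOAT, ["batch", "seq", 512])
k = helper.make_tensor_value_info("key", TensorProto.FLOAT, ["batch", "seq", 512])
v = helper.make_tensor_value_info("value", TensorProto.FLOAT, ["batch", "seq", 512])
out = helper.make_tensor_value_info("attn_out", TensorProto.FLOAT, ["batch", "seq", 512])
mha = helper.make_node("MultiHeadAttention", ["query", "key", "value"], ["attn_out"],
                       name="mha_0", domain="com.microsoft", num_heads=8)
proto = helper.make_model(helper.make_graph([mha], "demo", [q, k, v], [out]))

model = add_cache_indirection_to_mha(OnnxModel(proto), past_seq_len_name="past_sequence_length")
node = model.model.graph.node[0]
print(len(node.input), list(node.input)[-2:])  # 10 ['past_sequence_length', 'cache_indirection']
print(model.model.graph.input[-1].name)        # cache_indirection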
+ def add_output_qk_to_mha(model: OnnxModel, dtype: int = 0, skip_node_idxs: list[int] = []): # noqa: B006
1399
+ # Add output_qk as output to MultiHeadAttention ops and as outputs to model
1400
+ output_qk_basename = "output_cross_qk"
1401
+ output_qks = []
1402
+ mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
1403
+ for idx, node in enumerate(mha_nodes):
1404
+ # Skip MHA nodes where output_qk does not need to be added
1405
+ if idx in skip_node_idxs:
1406
+ continue
1407
+
1408
+ # Get `num_heads` attribute from MHA
1409
+ num_heads = 0
1410
+ for att in node.attribute:
1411
+ if att.name == "num_heads":
1412
+ num_heads = att.i
1413
+ break
1414
+
1415
+ # Get dtype for `output_qk` based on MHA bias if not provided
1416
+ output_qk_dtype = dtype
1417
+ if output_qk_dtype == 0:
1418
+ for i in model.model.graph.initializer:
1419
+ if i.name == node.input[3]:
1420
+ output_qk_dtype = i.data_type
1421
+ break
1422
+
1423
+ # Get the `target_sequence_length` dim from the 4D key input's shape if it is a constant
1424
+ target_sequence_length = "target_sequence_length"
1425
+ for i in model.model.graph.input:
1426
+ if i.name == node.input[1]:
1427
+ target_sequence_length = i.type.tensor_type.shape.dim[2].dim_value
1428
+ break
1429
+
1430
+ # MHA op takes the following potential outputs:
1431
+ # output, present_key, present_value
1432
+ while len(node.output) < 3:
1433
+ node.output.append("")
1434
+
1435
+ output_qk_name = f"{output_qk_basename}_{idx // 2}"
1436
+ node.output.append(output_qk_name)
1437
+ output_qks.append(
1438
+ onnx.helper.make_tensor_value_info(
1439
+ output_qk_name,
1440
+ output_qk_dtype,
1441
+ shape=["batch_size", num_heads, "sequence_length", target_sequence_length],
1442
+ ),
1443
+ )
1444
+
1445
+ model.model.graph.output.extend(output_qks)
1446
+ model.topological_sort()
1447
+ return model
1448
+
1449
+
1450
+ def fix_past_sequence_length(model: OnnxModel):
1451
+ # Modify total_sequence_length = past_sequence_length + curr_sequence_length subgraph to calculate
1452
+ # past_sequence_length from the new 1D `past_sequence_length` input of type int32 instead of
1453
+ # from `past_key_self_0` since DecoderMaskedMultiHeadAttention (DMMHA) uses buffer sharing and
1454
+ # `past_key_self_0.shape[2] = max_sequence_length` instead of `past_key_self_0.shape[2] = past_sequence_length`
1455
+ # when buffer sharing is enabled
1456
+ #
1457
+ # Before:
1458
+ #
1459
+ # input_ids past_key_self_0
1460
+ # | |
1461
+ # Shape Shape
1462
+ # | |
1463
+ # Gather Gather
1464
+ # (idx=1) (idx=2)
1465
+ # | | \
1466
+ # +--------+--------+ Unsqueeze
1467
+ # |
1468
+ # Add
1469
+ #
1470
+ # After:
1471
+ #
1472
+ # input_ids past_sequence_length (1D)
1473
+ # | |
1474
+ # Shape Squeeze
1475
+ # | |
1476
+ # Gather Cast
1477
+ # (idx=1) (int64)
1478
+ # | | \
1479
+ # +--------+--------+ Unsqueeze
1480
+ # |
1481
+ # Add
1482
+
1483
+ # Constant names to be used
1484
+ past_seq_len_name = "past_sequence_length"
1485
+ past_seq_len_int32 = "past_seq_len_int32"
1486
+ past_seq_len_int64 = "past_seq_len_int64"
1487
+
1488
+ node = list(filter(lambda n: n.op_type == "LayerNormalization", model.model.graph.node))[0] # noqa: RUF015
1489
+
1490
+ base_path_hf = model.match_parent_path(
1491
+ node,
1492
+ ["Add", "Gather", "Tile", "Expand", "Unsqueeze", "Range"],
1493
+ [0, 1, 1, 0, 0, 0],
1494
+ )
1495
+ base_path_oai = model.match_parent_path(
1496
+ node,
1497
+ ["Add", "Slice"],
1498
+ [0, 1],
1499
+ )
1500
+ if base_path_hf is not None:
1501
+ base_path = base_path_hf
1502
+ elif base_path_oai is not None:
1503
+ base_path = base_path_oai
1504
+ else:
1505
+ logger.info("Cannot identify base path for fixing past_sequence_length subgraph")
1506
+ return
1507
+ base_node = base_path[-1]
1508
+
1509
+ if base_node.op_type == "Range":
1510
+ # Hugging Face implementation
1511
+ range_node = base_path[-1]
1512
+
1513
+ gather_path = model.match_parent_path(
1514
+ range_node,
1515
+ ["Gather", "Shape"],
1516
+ [0, 0],
1517
+ )
1518
+ if gather_path is None:
1519
+ logger.info("Cannot identify gather path for fixing past_sequence_length subgraph")
1520
+ return
1521
+
1522
+ add_path = model.match_parent_path(
1523
+ range_node,
1524
+ ["Add", "Gather", "Shape"],
1525
+ [1, 0, 0],
1526
+ )
1527
+ if add_path is None:
1528
+ logger.info("Cannot identify add path for fixing past_sequence_length subgraph")
1529
+ return
1530
+ add_node = add_path[0]
1531
+
1532
+ if gather_path != add_path[1:]:
1533
+ logger.info("Gather path and add path do not share the same nodes for calculating the past_sequence_length")
1534
+ return
1535
+
1536
+ # Remove `past_key_self_0 --> Shape --> Gather` connection
1537
+ constant_in_gather = list(filter(lambda n: n.output[0] == gather_path[0].input[1], model.model.graph.node))[0] # noqa: RUF015
1538
+ model.model.graph.node.remove(constant_in_gather)
1539
+ model.model.graph.node.remove(gather_path[0])
1540
+ model.model.graph.node.remove(gather_path[1])
1541
+
1542
+ # Add `past_seq_len_int64` as an input name to existing nodes
1543
+ range_node.input[0] = past_seq_len_int64
1544
+ add_node.input[0] = past_seq_len_int64
1545
+
1546
+ else:
1547
+ # OpenAI implementation
1548
+ input_ids_path = model.match_parent_path(
1549
+ base_node,
1550
+ ["Unsqueeze", "Add", "Gather", "Shape", "Reshape", "Transpose"],
1551
+ [2, 0, 0, 0, 0, 0],
1552
+ )
1553
+ if input_ids_path is None:
1554
+ logger.info("Cannot identify input_ids path for fixing past_sequence_length subgraph")
1555
+ return
1556
+ add_node = input_ids_path[1]
1557
+
1558
+ past_key_path = model.match_parent_path(
1559
+ base_node,
1560
+ ["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"],
1561
+ [1, 0, 0, 0, 0],
1562
+ )
1563
+ if past_key_path is None:
1564
+ logger.info("Cannot identify past_key path for fixing past_sequence_length subgraph")
1565
+ return
1566
+ unsqueeze_node = past_key_path[0]
1567
+
1568
+ if input_ids_path[2:] != past_key_path[1:]:
1569
+ logger.info(
1570
+ "The input_ids path and past_key path do not share the same nodes for calculating the past_sequence_length"
1571
+ )
1572
+ return
1573
+
1574
+ # Remove `past_key_self_0 --> Transpose --> Reshape --> Shape --> Gather` connection
1575
+ constant_in_gather = list(filter(lambda n: n.output[0] == past_key_path[1].input[1], model.model.graph.node))[0] # noqa: RUF015
1576
+ model.model.graph.node.remove(constant_in_gather)
1577
+ constant_in_reshape = list(filter(lambda n: n.output[0] == past_key_path[-2].input[1], model.model.graph.node))[ # noqa: RUF015
1578
+ 0
1579
+ ]
1580
+ model.model.graph.node.remove(constant_in_reshape)
1581
+ model.model.graph.node.remove(past_key_path[1])
1582
+ model.model.graph.node.remove(past_key_path[2])
1583
+ model.model.graph.node.remove(past_key_path[3])
1584
+ model.model.graph.node.remove(past_key_path[4])
1585
+
1586
+ # Add `past_seq_len_int64` as an input name to existing nodes
1587
+ unsqueeze_node.input[0] = past_seq_len_int64
1588
+ add_node.input[0] = past_seq_len_int64
1589
+
1590
+ # Add `past_sequence_length` as model input
1591
+ model.model.graph.input.append(
1592
+ onnx.helper.make_tensor_value_info(past_seq_len_name, TensorProto.INT32, shape=[1]),
1593
+ )
1594
+
1595
+ # Add `past_sequence_length --> Squeeze --> Cast` connection
1596
+ squeeze_node = onnx.helper.make_node(
1597
+ "Squeeze",
1598
+ inputs=[past_seq_len_name],
1599
+ outputs=[past_seq_len_int32],
1600
+ name=model.create_node_name("Squeeze"),
1601
+ )
1602
+ squeeze_output = onnx.helper.make_tensor_value_info(past_seq_len_int32, TensorProto.INT32, shape=[])
1603
+ cast_node = onnx.helper.make_node(
1604
+ "Cast",
1605
+ inputs=[past_seq_len_int32],
1606
+ outputs=[past_seq_len_int64],
1607
+ name=model.create_node_name("Cast"),
1608
+ to=TensorProto.INT64,
1609
+ )
1610
+ cast_output = onnx.helper.make_tensor_value_info(past_seq_len_int64, TensorProto.INT64, shape=[])
1611
+
1612
+ # Add new nodes to graph
1613
+ model.model.graph.node.extend([squeeze_node, cast_node])
1614
+ model.model.graph.value_info.extend([squeeze_output, cast_output])
1615
+ model.topological_sort()
1616
+ return model, past_seq_len_name
1617
+
1618
+
1619
+ def replace_mha_with_dmmha(model: OnnxModel, past_seq_len_name: str):
1620
+ # Add `beam_width` and `cache_indirection` as model inputs
1621
+ beam_width = "beam_width"
1622
+ cache_indirection = "cache_indirection"
1623
+
1624
+ model.model.graph.input.extend(
1625
+ [
1626
+ onnx.helper.make_tensor_value_info(beam_width, TensorProto.INT32, shape=[1]),
1627
+ onnx.helper.make_tensor_value_info(
1628
+ cache_indirection, TensorProto.INT32, shape=["batch_size", "beam_width", "max_sequence_length"]
1629
+ ),
1630
+ ]
1631
+ )
1632
+
1633
+ # Replace all `MultiHeadAttention` nodes with `DecoderMaskedMultiHeadAttention` nodes
1634
+ mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
1635
+ for idx, node in enumerate(mha_nodes):
1636
+ # Get `num_heads` attribute from MHA
1637
+ num_heads = 0
1638
+ for att in node.attribute:
1639
+ if att.name == "num_heads":
1640
+ num_heads = att.i
1641
+ break
1642
+
1643
+ # Make Q*K outputs for cross-attention layers, which occur on every other layer
1644
+ qk_output_name = f"output_cross_qk_{idx // 2}"
1645
+ qk_output = onnx.helper.make_tensor_value_info(
1646
+ qk_output_name, TensorProto.FLOAT, shape=["batch_size", num_heads, 1, "encode_sequence_length / 2"]
1647
+ )
1648
+ if idx % 2 == 1:
1649
+ model.model.graph.output.append(qk_output)
1650
+
1651
+ # Make DMMHA node
1652
+ dmmha_node = onnx.helper.make_node(
1653
+ "DecoderMaskedMultiHeadAttention",
1654
+ inputs=[
1655
+ node.input[0], # query
1656
+ node.input[1], # key
1657
+ node.input[2], # value
1658
+ "", # mask_index
1659
+ "", # relative_position_bias
1660
+ node.input[6] if len(node.input) > 4 else "", # past_key
1661
+ node.input[7] if len(node.input) > 4 else "", # past_value
1662
+ past_seq_len_name, # past_sequence_length
1663
+ beam_width, # beam_width
1664
+ cache_indirection, # cache_indirection
1665
+ node.input[3], # bias
1666
+ ],
1667
+ outputs=[
1668
+ node.output[0], # output
1669
+ node.output[1] if len(node.input) > 4 else "", # present_key
1670
+ node.output[2] if len(node.input) > 4 else "", # present_value
1671
+ qk_output_name if idx % 2 == 1 else "", # output_cross_qk
1672
+ ],
1673
+ name=node.name.replace("MultiHeadAttention", "DecoderMaskedMultiHeadAttention"),
1674
+ domain="com.microsoft",
1675
+ num_heads=num_heads,
1676
+ output_qk=(idx % 2),
1677
+ past_present_share_buffer=1,
1678
+ )
1679
+ if idx % 2 == 0:
1680
+ # Remove the empty string placeholder for output_cross_qk, which is only produced on every other layer
1681
+ dmmha_node.output.remove("")
1682
+
1683
+ model.model.graph.node.remove(node)
1684
+ model.model.graph.node.extend([dmmha_node])
1685
+
1686
+ model.topological_sort()
1687
+ return model
1688
+
1689
+
1690
+ def replace_mha_with_gqa(
1691
+ model: OnnxModel,
1692
+ attn_mask: str,
1693
+ kv_num_heads: int = 0,
1694
+ world_size: int = 1,
1695
+ window_size: int = -1,
1696
+ ):
1697
+ # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
1698
+ #
1699
+ # attention_mask
1700
+ # / \
1701
+ # ReduceSum Shape
1702
+ # | |
1703
+ # Sub Gather
1704
+ # | |
1705
+ # seqlens_k total_sequence_length
1706
+ # | |
1707
+ # Cast to int32 Cast to int32
1708
+
1709
+ model.add_initializer(
1710
+ onnx.helper.make_tensor(
1711
+ name="one",
1712
+ data_type=TensorProto.INT64,
1713
+ dims=[1],
1714
+ vals=[1],
1715
+ )
1716
+ )
1717
+ reduce_sum_node = onnx.helper.make_node(
1718
+ "ReduceSum",
1719
+ inputs=[attn_mask, "one"],
1720
+ outputs=[attn_mask + "_row_sums"],
1721
+ name=model.create_node_name("ReduceSum"),
1722
+ )
1723
+ sub_node = onnx.helper.make_node(
1724
+ "Sub",
1725
+ inputs=[attn_mask + "_row_sums", "one"],
1726
+ outputs=["seqlens_k_int64"],
1727
+ name=model.create_node_name("Sub"),
1728
+ )
1729
+ seqlen_k_cast_node = onnx.helper.make_node(
1730
+ "Cast",
1731
+ inputs=["seqlens_k_int64"],
1732
+ outputs=["seqlens_k"],
1733
+ name=model.create_node_name("Cast"),
1734
+ to=TensorProto.INT32,
1735
+ )
1736
+ shape_node = onnx.helper.make_node(
1737
+ "Shape",
1738
+ inputs=[attn_mask],
1739
+ outputs=[attn_mask + "_shape"],
1740
+ name=model.create_node_name("Shape"),
1741
+ )
1742
+ gather_node = onnx.helper.make_node(
1743
+ "Gather",
1744
+ inputs=[attn_mask + "_shape", "one"],
1745
+ outputs=["total_seq_len_int64"],
1746
+ name=model.create_node_name("Gather"),
1747
+ axis=0,
1748
+ )
1749
+ total_seqlen_cast_node = onnx.helper.make_node(
1750
+ "Cast",
1751
+ inputs=["total_seq_len_int64"],
1752
+ outputs=["total_seq_len"],
1753
+ name=model.create_node_name("Cast"),
1754
+ to=TensorProto.INT32,
1755
+ )
1756
+ model.model.graph.node.extend(
1757
+ [
1758
+ reduce_sum_node,
1759
+ sub_node,
1760
+ seqlen_k_cast_node,
1761
+ shape_node,
1762
+ gather_node,
1763
+ total_seqlen_cast_node,
1764
+ ]
1765
+ )
1766
+
1767
+ # Replace MultiHeadAttention with GroupQueryAttention
1768
+ #
1769
+ # When replacing, fuse the following subgraph:
1770
+ #
1771
+ # root_input
1772
+ # / | \
1773
+ # MatMul MatMul MatMul
1774
+ # | | |
1775
+ # Add Add Add (optional Adds)
1776
+ # | | |
1777
+ # RotEmb RotEmb |
1778
+ # \ | /
1779
+ # MultiHeadAttention
1780
+ #
1781
+ # to this new subgraph:
1782
+ #
1783
+ # root_input
1784
+ # |
1785
+ # PackedMatMul (if possible)
1786
+ # |
1787
+ # PackedAdd (if possible)
1788
+ # |
1789
+ # GroupQueryAttention
1790
+ #
1791
+
1792
+ mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
1793
+ for idx, node in enumerate(mha_nodes):
1794
+ # Detect Q path to MHA
1795
+ q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
1796
+ q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])
1797
+
1798
+ q_rotary, q_add, q_matmul = None, None, None
1799
+ if q_path_1 is not None:
1800
+ q_rotary, q_add, q_matmul = q_path_1
1801
+ elif q_path_2 is not None:
1802
+ q_rotary, q_matmul = q_path_2
1803
+
1804
+ # Detect K path to MHA
1805
+ k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
1806
+ k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])
1807
+
1808
+ k_rotary, k_add, k_matmul = None, None, None
1809
+ if k_path_1 is not None:
1810
+ k_rotary, k_add, k_matmul = k_path_1
1811
+ elif k_path_2 is not None:
1812
+ k_rotary, k_matmul = k_path_2
1813
+
1814
+ # Detect V path to MHA
1815
+ v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
1816
+ v_path_2 = model.match_parent_path(node, ["MatMul"], [2])
1817
+
1818
+ v_add, v_matmul = None, None
1819
+ if v_path_1 is not None:
1820
+ v_add, v_matmul = v_path_1
1821
+ elif v_path_2 is not None:
1822
+ v_matmul = v_path_2[0]
1823
+
1824
+ # Get `interleaved` attribute from RotaryEmbedding
1825
+ interleaved = 0
1826
+ if q_rotary is not None and k_rotary is not None:
1827
+ for att in q_rotary.attribute:
1828
+ if att.name == "interleaved":
1829
+ interleaved = att.i
1830
+
1831
+ # Get `num_heads` attribute from MHA
1832
+ num_heads = 0
1833
+ for att in node.attribute:
1834
+ if att.name == "num_heads":
1835
+ num_heads = att.i
1836
+
1837
+ # Check if root_input to Q/K/V paths is the same
1838
+ root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
1839
+
1840
+ # Check if Q/K/V paths all have bias or all don't have bias
1841
+ all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
1842
+ all_paths_have_no_bias = q_add is None and k_add is None and v_add is None
1843
+
1844
+ # Make PackedMatMul node if possible
1845
+ q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
1846
+ if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
1847
+ qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
1848
+ kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
1849
+ vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))
1850
+
1851
+ dim = qw.shape[-1]
1852
+ qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
1853
+ qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
1854
+ model.add_initializer(qkv_weight)
1855
+
1856
+ packed_matmul_node = onnx.helper.make_node(
1857
+ "MatMul",
1858
+ inputs=[q_matmul.input[0], qkv_weight.name],
1859
+ outputs=[f"{qkv_weight.name}_output"],
1860
+ name=model.create_node_name("MatMul"),
1861
+ )
1862
+ model.model.graph.node.extend([packed_matmul_node])
1863
+ model.model.graph.node.remove(q_matmul)
1864
+ model.model.graph.node.remove(k_matmul)
1865
+ model.model.graph.node.remove(v_matmul)
1866
+ q_input_to_attention = packed_matmul_node.output[0]
1867
+
1868
+ # Make PackedAdd node if possible
1869
+ if all_paths_have_bias:
1870
+ qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
1871
+ kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
1872
+ vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))
1873
+
1874
+ dim = qb.shape[-1]
1875
+ qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
1876
+ qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
1877
+ model.add_initializer(qkv_bias)
1878
+ packed_add_node = onnx.helper.make_node(
1879
+ "Add",
1880
+ inputs=[packed_matmul_node.output[0], qkv_bias.name],
1881
+ outputs=[f"{qkv_bias.name}_output"],
1882
+ )
1883
+ model.model.graph.node.extend([packed_add_node])
1884
+ model.model.graph.node.remove(q_add)
1885
+ model.model.graph.node.remove(k_add)
1886
+ model.model.graph.node.remove(v_add)
1887
+ q_input_to_attention = packed_add_node.output[0]
1888
+
1889
+ else:
1890
+ q_input_to_attention = q_matmul.output[0]
1891
+ k_input_to_attention = k_matmul.output[0]
1892
+ v_input_to_attention = v_matmul.output[0]
1893
+
1894
+ # Make GQA node
1895
+ gqa_node = onnx.helper.make_node(
1896
+ "GroupQueryAttention",
1897
+ inputs=[
1898
+ q_input_to_attention, # query
1899
+ k_input_to_attention, # key
1900
+ v_input_to_attention, # value
1901
+ node.input[6], # past_key
1902
+ node.input[7], # past_value
1903
+ seqlen_k_cast_node.output[0], # seqlens_k (for attention mask)
1904
+ total_seqlen_cast_node.output[0], # total_seq_len (for attention mask)
1905
+ (q_rotary.input[2] if q_rotary is not None else ""), # cos_cache (for rotary embeddings)
1906
+ (q_rotary.input[3] if q_rotary is not None else ""), # sin_cache (for rotary embeddings)
1907
+ ],
1908
+ outputs=node.output,
1909
+ name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
1910
+ domain="com.microsoft",
1911
+ num_heads=num_heads // world_size,
1912
+ kv_num_heads=(num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size),
1913
+ local_window_size=window_size,
1914
+ do_rotary=int(q_rotary is not None and k_rotary is not None),
1915
+ rotary_interleaved=interleaved,
1916
+ )
1917
+ model.model.graph.node.remove(node)
1918
+ model.model.graph.node.extend([gqa_node])
1919
+
1920
+ if q_rotary is not None:
1921
+ model.model.graph.node.remove(q_rotary)
1922
+ if k_rotary is not None:
1923
+ model.model.graph.node.remove(k_rotary)
1924
+
1925
+ return model
1926
+
1927
+
1928
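A small numpy illustration, not ONNX code, of what the attention-mask subgraph inserted by replace_mha_with_gqa computes: seqlens_k is the per-row index of the last valid token and total_seq_len is the padded length, both shared by every GroupQueryAttention node. The example mask is invented.

import numpy as np

attention_mask = np.array([[1, 1, 1, 0],
                           [1, 1, 0, 0]], dtype=np.int64)

# ReduceSum(mask, axes=[1]) -> Sub(., 1) -> Cast(int32), mirroring the inserted nodes
seqlens_k = (attention_mask.sum(axis=1, keepdims=True) - 1).astype(np.int32)  # [[2], [1]]

# Shape(mask) -> Gather(idx=1) -> Cast(int32)
total_seq_len = np.int32(attention_mask.shape[1])                             # 4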
+ def update_decoder_subgraph_output_cross_attention(subg: GraphProto):
1929
+ input_self_past_0 = 1
1930
+ # w/wo attention mask, w/wo hidden_state
1931
+ graph_input_names = [gi.name for gi in subg.input]
1932
+ while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
1933
+ input_self_past_0 += 1
1934
+ output_self_present_0 = 1
1935
+
1936
+ num_layers = (len(subg.output) - output_self_present_0) // 2
1937
+ input_cross_past_0 = 2 * num_layers + input_self_past_0
1938
+ past_key_cross_inputs = {subg.input[layer * 2 + input_cross_past_0].name: layer for layer in range(num_layers)}
1939
+ print(f" -- past_key_cross_inputs = {past_key_cross_inputs}")
1940
+
1941
+ input_past_key_cross_0_shape = shape_of(subg.input[input_cross_past_0])
1942
+ print(f"past_key_cross_0_shape is {input_past_key_cross_0_shape}")
1943
+ batch_size_dim = input_past_key_cross_0_shape[0]
1944
+ num_heads_dim = input_past_key_cross_0_shape[1]
1945
+ cross_seq_len_dim = input_past_key_cross_0_shape[2]
1946
+
1947
+ num_layer_output_qk = 0
1948
+ for node in subg.node:
1949
+ if (node.op_type == "DecoderMaskedMultiHeadAttention") and (node.input[1] in past_key_cross_inputs):
1950
+ print(f" -- add cross QK output from: node: {node.name} with output: {node.output}")
1951
+ num_layer_output_qk += 1
1952
+ layer = past_key_cross_inputs[node.input[1]]
1953
+ cross_attention_out_name = f"output_cross_qk_{layer}"
1954
+ appended_names = [""] * (3 - len(node.output))
1955
+ appended_names.append(cross_attention_out_name)
1956
+ node.output.extend(appended_names)
1957
+ node.attribute.extend([onnx.helper.make_attribute("output_qk", 1)])
1958
+
1959
+ cross_attention = onnx.helper.make_tensor_value_info(
1960
+ cross_attention_out_name,
1961
+ TensorProto.FLOAT,
1962
+ [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim],
1963
+ )
1964
+ subg.output.extend([cross_attention])
1965
+ if num_layer_output_qk != num_layers:
1966
+ raise ValueError(f"Did not add cross QK for all layers{num_layers} vs {num_layer_output_qk}")
1967
+
1968
+
1969
+ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelProto):
1970
+ input_self_past_0 = 1
1971
+ # w/wo attention mask, w/wo hidden_state
1972
+ graph_input_names = [gi.name for gi in subg.input]
1973
+ while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
1974
+ input_self_past_0 += 1
1975
+ output_self_past_0 = 1
1976
+
1977
+ num_layers = int((len(subg.input) - input_self_past_0) / 4)
1978
+ input_cross_past_0 = 2 * num_layers + input_self_past_0
1979
+
1980
+ new_nodes = []
1981
+ old_nodes = []
1982
+ for node in subg.node:
1983
+ if node.op_type == "MultiHeadAttention":
1984
+ old_nodes.extend([node])
1985
+
1986
+ # If not all the MultiHeadAttention nodes are fused, this optimization is not applicable
1987
+ if len(old_nodes) < num_layers:
1988
+ return False
1989
+
1990
+ # Redirect the RelativePositionBias node's input from past_key_self_0.shape[2] to past_sequence_length.
1991
+ # There is only one RelativePositionBias node in T5 decoder subgraph.
1992
+ rel_pos_bias_node = None
1993
+ for node in subg.node:
1994
+ if node.op_type == "RelativePositionBias":
1995
+ rel_pos_bias_node = node
1996
+ break
1997
+
1998
+ decoder_masked_attention_supported_attr = [
1999
+ "past_present_share_buffer",
2000
+ "num_heads",
2001
+ "scale",
2002
+ "mask_filter_value",
2003
+ "domain",
2004
+ ]
2005
+
2006
+ target_squeezed_past_seq_name = "past_sequence_length_squeezed_int64"
2007
+ tensor_names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg)
2008
+ if len(tensor_names_to_rename) > 0:
2009
+ for name_to_rename in tensor_names_to_rename:
2010
+ print(f"Found tensor name `{name_to_rename}` to be renamed to `{target_squeezed_past_seq_name}`")
2011
+ for nr in nodes_to_remove:
2012
+ print(f"Found node to remove: type = {nr.op_type}, name = {nr.name}")
2013
+
2014
+ squeeze_node = onnx.helper.make_node(
2015
+ "Squeeze",
2016
+ ["past_sequence_length"],
2017
+ ["past_sequence_length_squeezed"],
2018
+ name="node_past_sequence_length_squeeze",
2019
+ )
2020
+ cast_node = onnx.helper.make_node(
2021
+ "Cast",
2022
+ ["past_sequence_length_squeezed"],
2023
+ [target_squeezed_past_seq_name],
2024
+ name="node_past_sequence_length_squeeze_cast",
2025
+ to=TensorProto.INT64,
2026
+ )
2027
+ new_nodes.extend([squeeze_node, cast_node])
2028
+
2029
+ for node in subg.node:
2030
+ if len(node.output) > 0 and rel_pos_bias_node is not None and node.output[0] == rel_pos_bias_node.input[1]:
2031
+ cast_node = onnx.helper.make_node(
2032
+ "Cast",
2033
+ ["past_sequence_length"],
2034
+ ["past_sequence_length_int64"],
2035
+ name="past_sequence_length_cast",
2036
+ to=TensorProto.INT64,
2037
+ )
2038
+ node.input[1] = cast_node.output[0]
2039
+ new_nodes.extend([cast_node])
2040
+
2041
+ if node.op_type == "MultiHeadAttention":
2042
+ kwargs = kwargs_of(node)
2043
+ for k in kwargs.copy():
2044
+ if k not in decoder_masked_attention_supported_attr:
2045
+ del kwargs[k]
2046
+
2047
+ # Note: this logic only applies to T5 models, where the Attention node has no bias.
2048
+ nis = [
2049
+ node.input[0], # query
2050
+ node.input[1], # key
2051
+ node.input[2], # value
2052
+ ]
2053
+
2054
+ nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask
2055
+ nis.extend([node.input[5] if len(node.input) > 5 else ""]) # attention_bias
2056
+ nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key
2057
+ nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value
2058
+ nis.extend(["past_sequence_length"]) # past_sequence_length
2059
+ nis.extend(["beam_width"]) # beam_width
2060
+ nis.extend(["cache_indirection"]) # cache_indirection
2061
+ nis.extend([node.input[3] if len(node.input) > 3 else ""]) # bias
2062
+
2063
+ kwargs["past_present_share_buffer"] = 1
2064
+
2065
+ node = onnx.helper.make_node( # noqa: PLW2901
2066
+ "DecoderMaskedMultiHeadAttention",
2067
+ nis,
2068
+ node.output,
2069
+ name=node.name,
2070
+ **kwargs,
2071
+ )
2072
+
2073
+ if node not in nodes_to_remove:
2074
+ for index, name in enumerate(node.input):
2075
+ if name in tensor_names_to_rename:
2076
+ node.input[index] = target_squeezed_past_seq_name
2077
+ new_nodes.extend([node])
2078
+
2079
+ subg.ClearField("node")
2080
+ subg.node.extend(new_nodes)
2081
+ orig_input_names = [inp.name for inp in subg.input]
2082
+
2083
+ new_inputs = []
2084
+ for i, vi in enumerate(subg.input):
2085
+ if i >= input_self_past_0 and i < input_cross_past_0:
2086
+ shape = shape_of(vi)
2087
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
2088
+ vi.name,
2089
+ elem_type=vi.type.tensor_type.elem_type,
2090
+ shape=[shape[0], shape[1], "max_seq_len", shape[3]],
2091
+ )
2092
+ new_inputs.extend([vi])
2093
+ if "past_sequence_length" not in orig_input_names:
2094
+ new_inputs.extend(
2095
+ [onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])]
2096
+ )
2097
+ if "beam_width" not in orig_input_names:
2098
+ new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
2099
+ if "cache_indirection" not in orig_input_names:
2100
+ new_inputs.extend(
2101
+ [
2102
+ onnx.helper.make_tensor_value_info(
2103
+ "cache_indirection",
2104
+ onnx.TensorProto.INT32,
2105
+ shape=["batch_size", "beam_width", "max_seq_len"],
2106
+ )
2107
+ ]
2108
+ )
2109
+ subg.ClearField("input")
2110
+ subg.input.extend(new_inputs)
2111
+
2112
+ new_outputs = []
2113
+ for i, vi in enumerate(subg.output):
2114
+ if i >= output_self_past_0:
2115
+ shape = shape_of(vi)
2116
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
2117
+ vi.name,
2118
+ elem_type=vi.type.tensor_type.elem_type,
2119
+ shape=[shape[0], shape[1], "max_seq_len", shape[3]],
2120
+ )
2121
+ new_outputs.extend([vi])
2122
+ subg.ClearField("output")
2123
+ subg.output.extend(new_outputs)
2124
+
2125
+ return True
2126
+
2127
+
2128
+ def pack_qkv_for_decoder_masked_mha(model_proto: ModelProto):
2129
+ onnx_model = OnnxModel(model_proto)
2130
+ output_name_to_node = onnx_model.output_name_to_node()
2131
+
2132
+ nodes_to_add = []
2133
+ nodes_to_remove = []
2134
+ for node in onnx_model.nodes():
2135
+ if node.op_type == "DecoderMaskedMultiHeadAttention":
2136
+ if "past_key_cross" in node.input[1] and "past_value_cross" in node.input[2]:
2137
+ continue
2138
+ q_matmul = output_name_to_node[node.input[0]]
2139
+ k_matmul = output_name_to_node[node.input[1]]
2140
+ v_matmul = output_name_to_node[node.input[2]]
2141
+
2142
+ q_weight = onnx_model.get_initializer(q_matmul.input[1])
2143
+ k_weight = onnx_model.get_initializer(k_matmul.input[1])
2144
+ v_weight = onnx_model.get_initializer(v_matmul.input[1])
2145
+ if not (q_weight and k_weight and v_weight):
2146
+ return False
2147
+
2148
+ qw = NumpyHelper.to_array(q_weight)
2149
+ kw = NumpyHelper.to_array(k_weight)
2150
+ vw = NumpyHelper.to_array(v_weight)
2151
+
2152
+ qkv_weight = np.concatenate([qw, kw, vw], axis=1)
2153
+
2154
+ matmul_node_name = onnx_model.create_node_name("MatMul", name_prefix="MatMul_QKV")
2155
+ weight = onnx.helper.make_tensor(
2156
+ name=matmul_node_name + "_weight",
2157
+ data_type=(TensorProto.FLOAT if q_weight.data_type == 1 else TensorProto.FLOAT16),
2158
+ dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
2159
+ vals=qkv_weight.flatten().tolist(),
2160
+ )
2161
+
2162
+ model_proto.graph.initializer.extend([weight])
2163
+
2164
+ matmul_node = onnx.helper.make_node(
2165
+ "MatMul",
2166
+ inputs=[q_matmul.input[0], matmul_node_name + "_weight"],
2167
+ outputs=[matmul_node_name + "_out"],
2168
+ name=matmul_node_name,
2169
+ )
2170
+
2171
+ node.input[0] = matmul_node.output[0]
2172
+ node.input[1] = ""
2173
+ node.input[2] = ""
2174
+
2175
+ nodes_to_add.extend([matmul_node])
2176
+ nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
2177
+
2178
+ onnx_model.add_nodes(nodes_to_add)
2179
+ onnx_model.remove_nodes(nodes_to_remove)
2180
+ onnx_model.update_graph()
2181
+
2182
+ onnx_model.topological_sort()
2183
+
2184
+ return True
2185
+
2186
+
2187
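A tiny numpy sketch (hidden size 8 is arbitrary) of the weight packing performed by pack_qkv_for_decoder_masked_mha: the Q, K and V weights are concatenated along the column axis so a single MatMul yields the packed QKV, after which inputs 1 and 2 of DecoderMaskedMultiHeadAttention are left empty.

import numpy as np

hidden = 8
rng = np.random.default_rng(0)
qw, kw, vw = (rng.standard_normal((hidden, hidden), dtype=np.float32) for _ in range(3))

qkv_weight = np.concatenate([qw, kw, vw], axis=1)
print(qkv_weight.shape)  # (8, 24): Q, K and V projections side by side in one weight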
+ def update_input_shapes_for_gpt2_decoder_model(decoder_onnx_path: str, use_external_data_format: bool = True):
2188
+ """Update the input shapes for the inputs "input_ids" and "position_ids" and make the sequence length dim value 1 for each of them.
2189
+ The decoder model will be over-written.
2190
+
2191
+ Args:
2192
+ decoder_onnx_path (str): Path of GPT-2 decoder onnx model
2193
+ use_external_data_format (bool): whether to save tensors as external data.
2194
+ """
2195
+
2196
+ decoder_model_proto = onnx.load_model(decoder_onnx_path, load_external_data=True)
2197
+ for i in range(len(decoder_model_proto.graph.input)):
2198
+ if (
2199
+ decoder_model_proto.graph.input[i].name == "input_ids"
2200
+ or decoder_model_proto.graph.input[i].name == "position_ids"
2201
+ ):
2202
+ shape_dim_proto = decoder_model_proto.graph.input[i].type.tensor_type.shape.dim[1]
2203
+
2204
+ # Clear any existing dim_param first
2205
+ if shape_dim_proto.HasField("dim_param"):
2206
+ shape_dim_proto.Clear()
2207
+
2208
+ # Update dim_value to be 1
2209
+ shape_dim_proto.dim_value = 1
2210
+
2211
+ OnnxModel.save(
2212
+ decoder_model_proto,
2213
+ decoder_onnx_path,
2214
+ save_as_external_data=use_external_data_format,
2215
+ )
2216
+ return True
2217
+
2218
+
2219
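A minimal illustration (standalone value info, not a real GPT-2 graph) of the dim edit that update_input_shapes_for_gpt2_decoder_model applies to input_ids and position_ids: a symbolic sequence-length dim_param is cleared and replaced with a fixed dim_value of 1.

from onnx import TensorProto, helper

vi = helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq_len"])
dim = vi.type.tensor_type.shape.dim[1]
if dim.HasField("dim_param"):
    dim.Clear()
dim.dim_value = 1
print([d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim])  # ['batch_size', 1]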
+ def generate_gpt2_init_decoder(
2220
+ decoder_onnx_path: str,
2221
+ init_decoder_onnx_path: str,
2222
+ use_external_data_format: bool = True,
2223
+ ) -> bool:
2224
+ """Generates the initial decoder GPT2 subgraph and saves it for downstream use.
2225
+ The initial decoder model will be saved to init_decoder_onnx_path.
2226
+
2227
+ Args:
2228
+ decoder_onnx_path (str): Path of GPT-2 decoder onnx model
2229
+ init_decoder_onnx_path (str): Path of GPT-2 init decoder onnx model
2230
+ use_external_data_format (bool): whether to save tensors as external data.
2231
+ """
2232
+ init_decoder_model_proto = onnx.load_model(decoder_onnx_path, load_external_data=True)
2233
+
2234
+ logits_output_name = init_decoder_model_proto.graph.output[0].name
2235
+
2236
+ gpt2_init_decoder_model = OnnxModel(init_decoder_model_proto)
2237
+
2238
+ output_name_to_node = gpt2_init_decoder_model.output_name_to_node()
2239
+ assert logits_output_name in output_name_to_node
2240
+
2241
+ logits_matmul_node = output_name_to_node[logits_output_name]
2242
+
2243
+ # Sanity check - the logits need to be produced by a MatMul node
2244
+ if logits_matmul_node.op_type != "MatMul":
2245
+ return False
2246
+
2247
+ # Try to find the last residual Add
2248
+ # For fp16, there are Casts along the way
2249
+
2250
+ # Normalization Node is : LayerNormalization
2251
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
2252
+ logits_matmul_node,
2253
+ [
2254
+ "Cast",
2255
+ "LayerNormalization",
2256
+ "Add",
2257
+ "Add",
2258
+ "Cast",
2259
+ "MatMul",
2260
+ "Cast",
2261
+ "FastGelu",
2262
+ "Cast",
2263
+ "MatMul",
2264
+ "Cast",
2265
+ "LayerNormalization",
2266
+ "Add",
2267
+ ],
2268
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
2269
+ )
2270
+
2271
+ # Normalization Node is : SkipLayerNormalization
2272
+ if logits_matmul_to_residual_add_path is None:
2273
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
2274
+ logits_matmul_node,
2275
+ [
2276
+ "Cast",
2277
+ "SkipLayerNormalization",
2278
+ "Cast",
2279
+ "MatMul",
2280
+ "Cast",
2281
+ "FastGelu",
2282
+ "Cast",
2283
+ "MatMul",
2284
+ "Cast",
2285
+ "SkipLayerNormalization",
2286
+ ],
2287
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
2288
+ )
2289
+
2290
+ # Try without the Casts before and after the MatMuls
2291
+ if logits_matmul_to_residual_add_path is None:
2292
+ # Normalization Node is : LayerNormalization
2293
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
2294
+ logits_matmul_node,
2295
+ [
2296
+ "LayerNormalization",
2297
+ "Add",
2298
+ "Add",
2299
+ "MatMul",
2300
+ "FastGelu",
2301
+ "MatMul",
2302
+ "LayerNormalization",
2303
+ "Add",
2304
+ ],
2305
+ [0, 0, 1, 0, 0, 0, 0, 0],
2306
+ )
2307
+
2308
+ # Normalization Node is : SkipLayerNormalization
2309
+ if logits_matmul_to_residual_add_path is None:
2310
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
2311
+ logits_matmul_node,
2312
+ [
2313
+ "SkipLayerNormalization",
2314
+ "MatMul",
2315
+ "FastGelu",
2316
+ "MatMul",
2317
+ "SkipLayerNormalization",
2318
+ ],
2319
+ [0, 1, 0, 0, 0],
2320
+ )
2321
+
2322
+ # TODO(hasesh): Are there more permutations to try before returning?
2323
+ if logits_matmul_to_residual_add_path is None:
2324
+ return False
2325
+
2326
+ residual_add_node = logits_matmul_to_residual_add_path[-1]
2327
+
2328
+ # If the last node in the pattern is SkipLayerNormalization, we need to adjust our pattern searches accordingly
2329
+ is_skiplayernorm_path = residual_add_node.op_type == "SkipLayerNormalization"
2330
+
2331
+ # Regular LayerNormalization path
2332
+ if not is_skiplayernorm_path:
2333
+ residual_add_to_attention_parent_index = 0
2334
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2335
+ residual_add_node,
2336
+ ["Add", "Cast", "MatMul", "Attention"],
2337
+ [residual_add_to_attention_parent_index, 0, 0, 0],
2338
+ )
2339
+
2340
+ # Try other parent index of the residual Add node
2341
+ if residual_add_to_attention_path is None:
2342
+ residual_add_to_attention_parent_index = 1
2343
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2344
+ residual_add_node,
2345
+ ["Add", "Cast", "MatMul", "Attention"],
2346
+ [residual_add_to_attention_parent_index, 0, 0, 0],
2347
+ )
2348
+
2349
+ # Try without the Casts before and after the MatMuls
2350
+ if residual_add_to_attention_path is None:
2351
+ residual_add_to_attention_parent_index = 0
2352
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2353
+ residual_add_node,
2354
+ ["Add", "MatMul", "Attention"],
2355
+ [residual_add_to_attention_parent_index, 0, 0],
2356
+ )
2357
+
2358
+ # Try without the Casts before and after the MatMuls and other parent index of the residual Add node
2359
+ if residual_add_to_attention_path is None:
2360
+ residual_add_to_attention_parent_index = 1
2361
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2362
+ residual_add_node,
2363
+ ["Add", "MatMul", "Attention"],
2364
+ [residual_add_to_attention_parent_index, 0, 0],
2365
+ )
2366
+
2367
+ # SkipLayerNormalization path
2368
+ else:
2369
+ residual_add_to_attention_parent_index = 0
2370
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2371
+ residual_add_node,
2372
+ ["Cast", "MatMul", "Attention"],
2373
+ [residual_add_to_attention_parent_index, 0, 0],
2374
+ )
2375
+
2376
+ # Try other parent index of the residual Add node
2377
+ if residual_add_to_attention_path is None:
2378
+ residual_add_to_attention_parent_index = 1
2379
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2380
+ residual_add_node,
2381
+ ["Cast", "MatMul", "Attention"],
2382
+ [residual_add_to_attention_parent_index, 0, 0],
2383
+ )
2384
+
2385
+ # Try without the Casts before and after the MatMuls
2386
+ if residual_add_to_attention_path is None:
2387
+ residual_add_to_attention_parent_index = 0
2388
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2389
+ residual_add_node,
2390
+ ["MatMul", "Attention"],
2391
+ [residual_add_to_attention_parent_index, 0],
2392
+ )
2393
+
2394
+ # Try without the Casts before and after the MatMuls and other parent index of the residual Add node
2395
+ if residual_add_to_attention_path is None:
2396
+ residual_add_to_attention_parent_index = 1
2397
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
2398
+ residual_add_node,
2399
+ ["MatMul", "Attention"],
2400
+ [residual_add_to_attention_parent_index, 0],
2401
+ )
2402
+
2403
+ # TODO(hasesh): Are there more permutations to try before returning?
2404
+ if residual_add_to_attention_path is None:
2405
+ return False
2406
+
2407
+ residual_add_to_add_parent_index = 0 if residual_add_to_attention_parent_index == 1 else 1
2408
+
2409
+ # Regular LayerNormalization path
2410
+ if not is_skiplayernorm_path:
2411
+ add_before_residual_add = gpt2_init_decoder_model.match_parent(
2412
+ residual_add_node, "Add", residual_add_to_add_parent_index
2413
+ )
2414
+
2415
+ # SkipLayerNormalization path
2416
+ else:
2417
+ add_before_residual_add = gpt2_init_decoder_model.match_parent(
2418
+ residual_add_node,
2419
+ "SkipLayerNormalization",
2420
+ residual_add_to_add_parent_index,
2421
+ )
2422
+
2423
+ if add_before_residual_add is None:
2424
+ return False
2425
+
2426
+ attention = residual_add_to_attention_path[-1]
2427
+ matmul_after_attention = residual_add_to_attention_path[-2]
2428
+
2429
+ slice_starts = onnx.helper.make_tensor(
2430
+ name="SliceLastTokenStarts",
2431
+ data_type=TensorProto.INT32,
2432
+ dims=[1],
2433
+ vals=[-1],
2434
+ )
2435
+
2436
+ slice_ends = onnx.helper.make_tensor(
2437
+ name="SliceLastTokenEnds",
2438
+ data_type=TensorProto.INT32,
2439
+ dims=[1],
2440
+ vals=[-2],
2441
+ )
2442
+
2443
+ slice_axes = onnx.helper.make_tensor(
2444
+ name="SliceLastTokenAxes",
2445
+ data_type=TensorProto.INT32,
2446
+ dims=[1],
2447
+ vals=[1],
2448
+ )
2449
+
2450
+ slice_steps = onnx.helper.make_tensor(
2451
+ name="SliceLastTokenSteps",
2452
+ data_type=TensorProto.INT32,
2453
+ dims=[1],
2454
+ vals=[-1],
2455
+ )
2456
+
2457
+ gpt2_init_decoder_model.add_initializer(slice_starts)
2458
+ gpt2_init_decoder_model.add_initializer(slice_ends)
2459
+ gpt2_init_decoder_model.add_initializer(slice_axes)
2460
+ gpt2_init_decoder_model.add_initializer(slice_steps)
2461
+
2462
+ # Add Slice node to the graph such that it consumes the output of Attention
2463
+ slice_0_output_name = "edge_modified_" + attention.output[0]
2464
+ slice_node_0 = onnx.helper.make_node(
2465
+ "Slice",
2466
+ inputs=[
2467
+ attention.output[0],
2468
+ "SliceLastTokenStarts",
2469
+ "SliceLastTokenEnds",
2470
+ "SliceLastTokenAxes",
2471
+ "SliceLastTokenSteps",
2472
+ ],
2473
+ outputs=[slice_0_output_name],
2474
+ name=gpt2_init_decoder_model.create_node_name("Slice", "GatherLastToken_0_"),
2475
+ )
2476
+
2477
+ # Add Slice node to the graph such that it consumes the output of Add before the residual Add
2478
+ # If the 'Add' output is produced by a SkipLayerNormalization node, then adjust its output
2479
+ # index appropriately
2480
+ add_before_residual_add_output = (
2481
+ add_before_residual_add.output[0] if not is_skiplayernorm_path else add_before_residual_add.output[3]
2482
+ )
2483
+
2484
+ slice_1_output_name = "edge_modified_" + add_before_residual_add.output[0]
2485
+ slice_node_1 = onnx.helper.make_node(
2486
+ "Slice",
2487
+ inputs=[
2488
+ add_before_residual_add_output,
2489
+ "SliceLastTokenStarts",
2490
+ "SliceLastTokenEnds",
2491
+ "SliceLastTokenAxes",
2492
+ "SliceLastTokenSteps",
2493
+ ],
2494
+ outputs=[slice_1_output_name],
2495
+ name=gpt2_init_decoder_model.create_node_name("Slice", "GatherLastToken_1_"),
2496
+ )
2497
+
2498
+ # Add the 2 Slice nodes
2499
+ gpt2_init_decoder_model.add_node(slice_node_0)
2500
+ gpt2_init_decoder_model.add_node(slice_node_1)
2501
+
2502
+ # Adjust the input(s) to the nodes consuming the outputs of the added Slice nodes
2503
+ gpt2_init_decoder_model.replace_node_input(matmul_after_attention, attention.output[0], slice_0_output_name)
2504
+ gpt2_init_decoder_model.replace_node_input(residual_add_node, add_before_residual_add_output, slice_1_output_name)
2505
+
2506
+ # Topologically sort the updated graph
2507
+ gpt2_init_decoder_model.topological_sort()
2508
+
2509
+ # Save the init decoder model
2510
+ OnnxModel.save(
2511
+ init_decoder_model_proto,
2512
+ init_decoder_onnx_path,
2513
+ save_as_external_data=use_external_data_format,
2514
+ )
2515
+ return True
2516
+
2517
+
2518
+ def make_dim_proto_numeric_t5(model, config):
2519
+ """Make dim_proto numeric.
2520
+
2521
+ Args:
2522
+ model: T5 encoder and decoder model.
2523
+ config: T5 config.
2524
+ """
2525
+ sequence_length = str(1)
2526
+ num_heads = str(config.num_heads)
2527
+ hidden_size = str(config.d_model)
2528
+ head_size = str(config.d_kv)
2529
+
2530
+ for tensor in model.graph.output:
2531
+ for dim_proto in tensor.type.tensor_type.shape.dim:
2532
+ if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
2533
+ sequence_length,
2534
+ num_heads,
2535
+ hidden_size,
2536
+ head_size,
2537
+ ]:
2538
+ dim_value = int(dim_proto.dim_param)
2539
+ dim_proto.Clear()
2540
+ dim_proto.dim_value = dim_value
2541
+
2542
+ for tensor in model.graph.input:
2543
+ for dim_proto in tensor.type.tensor_type.shape.dim:
2544
+ if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
2545
+ sequence_length,
2546
+ num_heads,
2547
+ hidden_size,
2548
+ head_size,
2549
+ ]:
2550
+ dim_value = int(dim_proto.dim_param)
2551
+ dim_proto.Clear()
2552
+ dim_proto.dim_value = dim_value
2553
+
2554
+
2555
+ def convert_generation_model(
2556
+ args: argparse.Namespace,
2557
+ generation_type: GenerationType = GenerationType.BEAMSEARCH,
2558
+ ):
2559
+ """Convert model according to command line arguments.
2560
+
2561
+ Args:
2562
+ args (argparse.Namespace): arguments parsed from command line
2563
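+ generation_type (GenerationType, optional): type of generation operator to build. Defaults to GenerationType.BEAMSEARCH.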
+ """
2564
+ is_gpt2: bool = args.model_type == "gpt2"
2565
+ is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
2566
+ is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
2567
+ is_sampling: bool = generation_type == GenerationType.SAMPLING
2568
+ past_present_share_buffer: bool = args.past_present_share_buffer
2569
+
2570
+ logger.info(f"**** past_present_share_buffer={past_present_share_buffer}")
2571
+ if len(args.op_block_list) == 1 and args.op_block_list[0] == "auto":
2572
+ if is_gpt2 and args.precision == Precision.FLOAT16.value:
2573
+ args.op_block_list = [
2574
+ "Add",
2575
+ "LayerNormalization",
2576
+ "SkipLayerNormalization",
2577
+ "FastGelu",
2578
+ ]
2579
+ logger.info(f"**** Setting op_block_list to {args.op_block_list}")
2580
+ logger.info("**** use --op_block_list if you want to override the block operator list.")
2581
+ else:
2582
+ args.op_block_list = []
2583
+
2584
+ if is_greedysearch or is_sampling:
2585
+ if not is_gpt2:
2586
+ raise NotImplementedError("Currently only gpt2 with greedy search/sampling is supported")
2587
+ if args.output_sequences_scores:
2588
+ raise NotImplementedError("output_sequences_scores currently is not supported in greedy search/sampling")
2589
+ if args.output_token_scores:
2590
+ raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
2591
+
2592
+ # For BeamSearch, sharing buffers for past and present states is only supported
2593
+ # when using `use_decoder_masked_attention`
2594
+ if past_present_share_buffer and is_beamsearch and not args.use_decoder_masked_attention:
2595
+ raise ValueError(
2596
+ "`use_decoder_masked_attention` MUST be turned on to use `past_present_share_buffer` in case of BeamSearch"
2597
+ )
2598
+
2599
+ # Regardless of the generation type, using decoder masked multihead attention is only supported
2600
+ # when using `past_present_share_buffer`
2601
+ if args.use_decoder_masked_attention and not past_present_share_buffer:
2602
+ raise ValueError("`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_attention`")
2603
+
2604
+ # Regardless of the generation type, decoder masked multihead attention is only supported
2605
+ # on GPUs
2606
+ if args.use_decoder_masked_attention and not args.use_gpu:
2607
+ raise ValueError("`use_decoder_masked_attention` option is only supported on GPUs")
2608
+
2609
+ if is_gpt2:
2610
+ if args.decoder_onnx and os.path.exists(args.decoder_onnx):
2611
+ logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
2612
+ else:
2613
+ if not args.decoder_onnx:
2614
+ onnx_filename = f"{args.model_name_or_path}_past_{args.precision}.onnx"
2615
+ args.decoder_onnx = Path(Path(args.output).parent, onnx_filename).as_posix()
2616
+
2617
+ logger.info(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...")
2618
+ gpt2_to_onnx(args)
2619
+ else: # t5 or mt5
2620
+ if args.decoder_onnx and args.encoder_decoder_init_onnx:
2621
+ logger.info(
2622
+ f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}"
2623
+ )
2624
+ else:
2625
+ logger.info(f"Convert model {args.model_name_or_path} to onnx ...")
2626
+ t5_to_onnx(args)
2627
+
2628
+ # We only want to pad the logits MatMul weight in the decoder for fp16 models.
2629
+ # The inherent assumption is that fp16 models run on GPU for which all
2630
+ # dims need to be a multiple of 8 to leverage tensor cores.
2631
+ # NOTE: We currently only support padding the MatMul logits weight for GPT2 GreedySearch/BeamSearch.
2632
+ # This can be expanded to other models/decoding strategies later
2633
+ logits_matmul_weight_padded = False
2634
+ if (
2635
+ not args.disable_pad_vocab_size
2636
+ and args.precision == Precision.FLOAT16.value
2637
+ and is_gpt2
2638
+ and (is_beamsearch or is_greedysearch or is_sampling)
2639
+ ):
2640
+ logger.info(
2641
+ f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. "
2642
+ "The file will be overwritten."
2643
+ )
2644
+ logits_matmul_weight_padded = pad_weights_of_logits_matmul(args.decoder_onnx, args.use_external_data_format)
2645
+ if not logits_matmul_weight_padded:
2646
+ logger.warning(
2647
+ "Tried and failed to pad logits MatMul weights. Performance may be sub-optimal for this MatMul"
2648
+ )
2649
+
2650
+ gpt2_init_decoder_generated = False
2651
+ gpt2_init_decoder_onnx_path = None
2652
+ if (
2653
+ not args.disable_separate_gpt2_decoder_for_init_run
2654
+ and is_gpt2
2655
+ and (is_beamsearch or is_greedysearch or is_sampling)
2656
+ ):
2657
+ logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
2658
+
2659
+ gpt2_init_decoder_onnx_filename = f"gpt2_init_past_{args.precision}.onnx"
2660
+
2661
+ gpt2_init_decoder_onnx_path = Path(Path(args.output).parent, gpt2_init_decoder_onnx_filename).as_posix()
2662
+
2663
+ gpt2_init_decoder_generated = generate_gpt2_init_decoder(
2664
+ args.decoder_onnx,
2665
+ gpt2_init_decoder_onnx_path,
2666
+ args.use_external_data_format,
2667
+ )
2668
+
2669
+ if not gpt2_init_decoder_generated:
2670
+ logger.warning(
2671
+ "Tried and failed to generate the init decoder GPT2 model. "
2672
+ "Performance may be sub-optimal for the initial decoding run"
2673
+ )
2674
+
2675
+ # Update the graph input shapes for the non-initial decoder model to account
2676
+ # for the fact that the sequence length will always be 1
2677
+ if gpt2_init_decoder_generated and not update_input_shapes_for_gpt2_decoder_model(
2678
+ args.decoder_onnx, args.use_external_data_format
2679
+ ):
2680
+ # Can't proceed further - better to raise an exception
2681
+ raise ValueError("Could not update the input shapes for the non-initial decoder subgraph.")
2682
+
2683
+ # If the user explicitly requests running shape inference or if we padded/mutated
2684
+ # weight(s)/input shape(s) in the decoder, we want to run shape inference to capture the new
2685
+ # shapes
2686
+ if logits_matmul_weight_padded or args.run_shape_inference or gpt2_init_decoder_generated:
2687
+ logger.info(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
2688
+ shape_inference(args.decoder_onnx, args.use_external_data_format)
2689
+ if gpt2_init_decoder_generated:
2690
+ logger.info(f"Run symbolic shape inference on {gpt2_init_decoder_onnx_path}. The file will be overwritten.")
2691
+ shape_inference(gpt2_init_decoder_onnx_path, args.use_external_data_format)
2692
+
2693
+ if is_gpt2:
2694
+ config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2695
+ elif args.model_type == "t5":
2696
+ config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2697
+ else:
2698
+ config = MT5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2699
+
2700
+ if args.verbose:
2701
+ logger.info(f"Config={config}")
2702
+
2703
+ eos_token_id = config.eos_token_id
2704
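+ # GPT-2 has no dedicated padding token, so its EOS token id is reused for padding.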
+ pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id
2705
+ vocab_size = config.vocab_size
2706
+
2707
+ # if vocab_size is given in parameters use that.
2708
+ if args.vocab_size != -1:
2709
+ vocab_size = args.vocab_size
2710
+
2711
+ if args.eos_token_id != -1:
2712
+ eos_token_id = args.eos_token_id
2713
+ if args.pad_token_id != -1:
2714
+ pad_token_id = args.pad_token_id
2715
+
2716
+ decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True)
2717
+ decoder_model.graph.name = f"{args.model_type} decoder"
2718
+
2719
+ gpt2_init_decoder_model = None
2720
+ if args.model_type == "gpt2":
2721
+ verify_gpt2_subgraph(decoder_model.graph, args.precision)
2722
+
2723
+ # If we generated the init decoder model, verify that as well
2724
+ if gpt2_init_decoder_generated:
2725
+ gpt2_init_decoder_model = onnx.load_model(gpt2_init_decoder_onnx_path, load_external_data=True)
2726
+ gpt2_init_decoder_model.graph.name = f"{args.model_type} init decoder"
2727
+ verify_gpt2_subgraph(gpt2_init_decoder_model.graph, args.precision)
2728
+ else:
2729
+ verify_t5_decoder_subgraph(decoder_model.graph, args.precision)
2730
+
2731
+ inputs = None
2732
+ if is_beamsearch:
2733
+ inputs = [
2734
+ "input_ids",
2735
+ "max_length",
2736
+ "min_length",
2737
+ "num_beams",
2738
+ "num_return_sequences",
2739
+ "length_penalty",
2740
+ "repetition_penalty",
2741
+ ]
2742
+ elif is_greedysearch or is_sampling:
2743
+ inputs = [
2744
+ "input_ids",
2745
+ "max_length",
2746
+ "min_length",
2747
+ "repetition_penalty",
2748
+ ]
2749
+
2750
+ if args.vocab_mask:
2751
+ inputs.append("vocab_mask")
2752
+ else:
2753
+ inputs.append("")
2754
+
2755
+ if args.prefix_vocab_mask:
2756
+ inputs.append("prefix_vocab_mask")
2757
+ else:
2758
+ inputs.append("")
2759
+
2760
+ if args.custom_attention_mask:
2761
+ inputs.append("attention_mask")
2762
+ else:
2763
+ inputs.append("")
2764
+
2765
+ if is_sampling:
2766
+ if args.custom and args.presence_mask:
2767
+ inputs.append("presence_mask")
2768
+ else:
2769
+ inputs.append("")
2770
+
2771
+ if args.seed:
2772
+ inputs.append("seed")
2773
+
2774
+ outputs = ["sequences"]
2775
+ if args.output_sequences_scores:
2776
+ outputs.append("sequences_scores")
2777
+
2778
+ if args.output_token_scores:
2779
+ assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
2780
+ outputs.append("scores")
2781
+
2782
+ node = None
2783
+ if is_beamsearch:
2784
+ node = onnx.helper.make_node(
2785
+ "BeamSearch",
2786
+ inputs=inputs,
2787
+ outputs=outputs,
2788
+ name=f"BeamSearch_{args.model_type}",
2789
+ )
2790
+ elif is_greedysearch:
2791
+ node = onnx.helper.make_node(
2792
+ "GreedySearch",
2793
+ inputs=inputs,
2794
+ outputs=outputs,
2795
+ name=f"GreedySearch_{args.model_type}",
2796
+ )
2797
+ elif is_sampling:
2798
+ node = onnx.helper.make_node(
2799
+ "Sampling",
2800
+ inputs=inputs,
2801
+ outputs=outputs,
2802
+ name=f"Sampling_{args.model_type}",
2803
+ )
2804
+
2805
+ node.domain = "com.microsoft"
2806
+
2807
+ attr_to_extend = None
2808
+ if is_beamsearch:
2809
+ attr_to_extend = [
2810
+ onnx.helper.make_attribute("eos_token_id", eos_token_id),
2811
+ onnx.helper.make_attribute("pad_token_id", pad_token_id),
2812
+ onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
2813
+ onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
2814
+ onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
2815
+ ]
2816
+ elif is_greedysearch:
2817
+ attr_to_extend = [
2818
+ onnx.helper.make_attribute("eos_token_id", eos_token_id),
2819
+ onnx.helper.make_attribute("pad_token_id", pad_token_id),
2820
+ onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
2821
+ onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
2822
+ ]
2823
+ elif is_sampling:
2824
+ attr_to_extend = [
2825
+ onnx.helper.make_attribute("eos_token_id", eos_token_id),
2826
+ onnx.helper.make_attribute("pad_token_id", pad_token_id),
2827
+ onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
2828
+ onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
2829
+ onnx.helper.make_attribute("temperature", args.temperature),
2830
+ onnx.helper.make_attribute("top_p", args.top_p),
2831
+ onnx.helper.make_attribute("filter_value", args.filter_value),
2832
+ onnx.helper.make_attribute("min_tokens_to_keep", args.min_tokens_to_keep),
2833
+ onnx.helper.make_attribute("custom", args.custom),
2834
+ onnx.helper.make_attribute("presence_penalty", args.presence_penalty),
2835
+ ]
2836
+
2837
+ # Explicitly pass in the vocab size via an attribute
2838
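+ # (padding the logits MatMul weight enlarges the logits dimension, so the operator
+ # needs to know the original vocabulary size)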
+ if logits_matmul_weight_padded:
2839
+ attr_to_extend.extend([onnx.helper.make_attribute("vocab_size", vocab_size)])
2840
+
2841
+ node.attribute.extend(attr_to_extend)
2842
+
2843
+ initializers = []
2844
+
2845
+ if args.model_type in ["t5", "mt5"]:
2846
+ if args.run_shape_inference:
2847
+ logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
2848
+ shape_inference(args.encoder_decoder_init_onnx, args.use_external_data_format)
2849
+ encoder_model = onnx.load_model(args.encoder_decoder_init_onnx, load_external_data=True)
2850
+ suffix = "encoder" if len(encoder_model.graph.input) == 2 else "encoder and decoder init"
2851
+ encoder_model.graph.name = f"{args.model_type} {suffix}"
2852
+ verify_t5_encoder_decoder_init_subgraph(encoder_model.graph, args.precision)
2853
+
2854
+ make_dim_proto_numeric_t5(encoder_model, config)
2855
+ make_dim_proto_numeric_t5(decoder_model, config)
2856
+
2857
+ # Update decoder subgraph in preparation to use past present share buffer
2858
+ if past_present_share_buffer:
2859
+ if not args.use_decoder_masked_attention:
2860
+ raise ValueError("past_present_share_buffer is only supported with use_decoder_masked_attention")
2861
+
2862
+ logger.info(
2863
+ "*****update t5 decoder subgraph to share past/present buffer and use decoder_masked_multihead_attention*****"
2864
+ )
2865
+ if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
2866
+ logger.info("*****update t5 decoder subgraph successfully!!!*****")
2867
+ else:
2868
+ logger.info("*****DecoderMaskedMultiHeadAttention is not applied to T5 decoder*****")
2869
+
2870
+ if pack_qkv_for_decoder_masked_mha(decoder_model):
2871
+ logger.info("*****pack qkv for decoder masked mha successfully!!!*****")
2872
+ else:
2873
+ logger.info("*****pack qkv for decoder masked mha failed!!!*****")
2874
+
2875
+ if not args.disable_shared_initializers:
2876
+ # Moving initializers shared by the encoder and decoder subgraphs to the main graph could reduce memory usage in inference.
2877
+ initializers = get_shared_initializers(encoder_model, decoder_model)
2878
+ logger.info(
2879
+ f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in encoder and decoder subgraphs are moved to the main graph"
2880
+ )
2881
+
2882
+ # TODO(tianleiwu): investigate the following which causes error in inference
2883
+ # Move initializer from subgraph to main graph could reduce memory usage in inference.
2884
+ # moved_initializers = move_initializers(encoder_model.graph)
2885
+ # logger.info(
2886
+ # f"{len(moved_initializers)} initializers ({[i.name for i in moved_initializers]}) from the encoder are moved to the main graph"
2887
+ # )
2888
+ # initializers.extend(moved_initializers)
2889
+
2890
+ assert config.decoder_start_token_id >= 0, "decoder_start_token_id should be >= 0"
2891
+
2892
+ node.attribute.extend(
2893
+ [
2894
+ onnx.helper.make_attribute("encoder", encoder_model.graph),
2895
+ onnx.helper.make_attribute("decoder", decoder_model.graph),
2896
+ onnx.helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id),
2897
+ ]
2898
+ )
2899
+ else:
2900
+ if gpt2_init_decoder_generated:
2901
+ # Move shared initializers (shared between init decoder and decoder models) to the main
2902
+ # graph and remove them from these models
2903
+ if not args.disable_shared_initializers:
2904
+ # Unique shared initializers from the decoder and decoder_init could reduce memory usage in inference.
2905
+ initializers = get_shared_initializers(gpt2_init_decoder_model, decoder_model)
2906
+ logger.info(
2907
+ f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in decoder and init decoder subgraphs are moved to the main graph"
2908
+ )
2909
+
2910
+ # Update init decoder subgraph in preparation to use past present share buffer
2911
+ if past_present_share_buffer:
2912
+ logger.info("*****update init decoder subgraph to make past and present share buffer******************")
2913
+ update_decoder_subgraph_past_present_share_buffer(gpt2_init_decoder_model.graph)
2914
+
2915
+ # Update init decoder subgraph in preparation to use DecoderMaskedSelfAttention
2916
+ # NOTE: Even if we will not use DecoderMaskedSelfAttention in the init decoder subgraph
2917
+ # it makes the runtime changes cleaner if we keep both the init decoder and decoder subgraphs
2918
+ # same in terms of the subgraph inputs.
2919
+ if args.use_decoder_masked_attention and not update_decoder_subgraph_use_decoder_masked_attention(
2920
+ gpt2_init_decoder_model.graph, is_beamsearch, False
2921
+ ):
2922
+ raise ValueError("Could not update the init decoder subgraph to use DecoderMaskedSelfAttention")
2923
+
2924
+ node.attribute.append(onnx.helper.make_attribute("init_decoder", gpt2_init_decoder_model.graph))
2925
+ else:
2926
+ # Move initializer from subgraph to main graph could reduce memory usage in inference.
2927
+ initializers = move_initializers(decoder_model.graph)
2928
+ logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph")
2929
+
2930
+ # Update decoder subgraph in preparation to use past present share buffer
2931
+ if past_present_share_buffer:
2932
+ logger.info("*****update decoder subgraph to make past and present share buffer******************")
2933
+ update_decoder_subgraph_past_present_share_buffer(decoder_model.graph)
2934
+
2935
+ # Update decoder subgraph in preparation to use DecoderMaskedSelfAttention
2936
+ if args.use_decoder_masked_attention and not update_decoder_subgraph_use_decoder_masked_attention(
2937
+ decoder_model.graph, is_beamsearch, True
2938
+ ):
2939
+ raise ValueError("Could not update the decoder subgraph to use DecoderMaskedSelfAttention")
2940
+
2941
+ node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph))
2942
+
2943
+ # graph inputs
2944
+ input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
2945
+ max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
2946
+ min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
2947
+ num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
2948
+ num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
2949
+ length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
2950
+ repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
2951
+
2952
+ graph_inputs = None
2953
+ if is_beamsearch:
2954
+ graph_inputs = [
2955
+ input_ids,
2956
+ max_length,
2957
+ min_length,
2958
+ num_beams,
2959
+ num_return_sequences,
2960
+ length_penalty,
2961
+ repetition_penalty,
2962
+ ]
2963
+ elif is_greedysearch or is_sampling:
2964
+ graph_inputs = [
2965
+ input_ids,
2966
+ max_length,
2967
+ min_length,
2968
+ repetition_penalty,
2969
+ ]
2970
+
2971
+ if args.vocab_mask:
2972
+ vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
2973
+ graph_inputs.append(vocab_mask)
2974
+
2975
+ if args.prefix_vocab_mask:
2976
+ prefix_vocab_mask = onnx.helper.make_tensor_value_info(
2977
+ "prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size]
2978
+ )
2979
+ graph_inputs.append(prefix_vocab_mask)
2980
+
2981
+ if args.custom_attention_mask:
2982
+ attention_mask = onnx.helper.make_tensor_value_info(
2983
+ "attention_mask", TensorProto.INT32, ["batch_size", "sequence_length"]
2984
+ )
2985
+ graph_inputs.append(attention_mask)
2986
+
2987
+ if args.custom and args.presence_mask:
2988
+ presence_mask = onnx.helper.make_tensor_value_info(
2989
+ "presence_mask", TensorProto.INT32, ["batch_size", vocab_size]
2990
+ )
2991
+ graph_inputs.append(presence_mask)
2992
+
2993
+ if is_sampling and args.seed:
2994
+ seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
2995
+ graph_inputs.append(seed)
2996
+
2997
+ # graph outputs
2998
+ sequences = None
2999
+ if is_beamsearch:
3000
+ sequences = onnx.helper.make_tensor_value_info(
3001
+ "sequences",
3002
+ TensorProto.INT32,
3003
+ ["batch_size", "num_return_sequences", "max_length"],
3004
+ )
3005
+ elif is_greedysearch or is_sampling:
3006
+ sequences = onnx.helper.make_tensor_value_info(
3007
+ "sequences",
3008
+ TensorProto.INT32,
3009
+ ["batch_size", "max_length"],
3010
+ )
3011
+
3012
+ graph_outputs = [sequences]
3013
+
3014
+ if args.output_sequences_scores:
3015
+ sequences_scores = onnx.helper.make_tensor_value_info(
3016
+ "sequences_scores",
3017
+ TensorProto.FLOAT,
3018
+ ["batch_size", "num_return_sequences"],
3019
+ )
3020
+ graph_outputs.append(sequences_scores)
3021
+
3022
+ if args.output_token_scores:
3023
+ scores = onnx.helper.make_tensor_value_info(
3024
+ "scores",
3025
+ TensorProto.FLOAT,
3026
+ ["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
3027
+ )
3028
+ graph_outputs.append(scores)
3029
+
3030
+ new_graph = onnx.helper.make_graph(
3031
+ [node],
3032
+ (f"{args.model_type} beam search" if is_beamsearch else f"{args.model_type} greedy search" if is_greedysearch else f"{args.model_type} sampling"),
3033
+ graph_inputs,
3034
+ graph_outputs,
3035
+ initializers,
3036
+ )
3037
+
3038
+ # Create the model
3039
+ new_model = onnx.helper.make_model(
3040
+ new_graph,
3041
+ producer_name="onnxruntime.transformers",
3042
+ opset_imports=decoder_model.opset_import,
3043
+ )
3044
+
3045
+ # TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory.
3046
+ if args.use_external_data_format:
3047
+ from packaging import version # noqa: PLC0415
3048
+
3049
+ if version.parse(onnx.__version__) < version.parse("1.12.0"):
3050
+ logger.warning("Require onnx >= 1.12 to save large (>2GB) model!")
3051
+
3052
+ OnnxModel.save(
3053
+ new_model,
3054
+ args.output,
3055
+ save_as_external_data=True,
3056
+ all_tensors_to_one_file=True,
3057
+ )
3058
+ else:
3059
+ onnx.save(new_model, args.output)
3060
+ logger.info(f"Model saved to {args.output}")
3061
+
3062
+
3063
+ def test_torch_performance(
3064
+ args: argparse.Namespace,
3065
+ model: GPT2LMHeadModel | T5ForConditionalGeneration,
3066
+ input_ids: torch.Tensor,
3067
+ attention_mask: torch.Tensor,
3068
+ eos_token_id: int,
3069
+ pad_token_id: int,
3070
+ bad_words_ids: list[list[int]],
3071
+ ) -> dict[str, Any]:
3072
+ """Test PyTorch performance of text generation.
3073
+
3074
+ Args:
3075
+ args (argparse.Namespace): arguments parsed from command line
3076
+ model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model
3077
+ input_ids (torch.Tensor): input_ids
3078
+ attention_mask (torch.Tensor): Attention mask
3079
+ eos_token_id (int): EOS token ID
3080
+ pad_token_id (int): Padding token ID
3081
+ bad_words_ids (List[List[int]]): IDs of words that shall not be generated.
3082
+
3083
+ Raises:
3084
+ RuntimeError: PyTorch with CUDA is not available for --use_gpu
3085
+
3086
+ Returns:
3087
+ Dict[str, Any]: A dictionary mapping each metric name to its value (integer or string).
3088
+ """
3089
+ if args.use_gpu and not torch.cuda.is_available():
3090
+ raise RuntimeError("Please install PyTorch with CUDA support for testing GPU performance.")
3091
+
3092
+ if args.precision == Precision.FLOAT16.value:
3093
+ model.half()
3094
+
3095
+ device = torch.device("cuda:0" if args.use_gpu else "cpu")
3096
+ model.to(device)
3097
+
3098
+ torch.set_grad_enabled(False)
3099
+ input_ids = input_ids.to(device)
3100
+ attention_mask = attention_mask.to(device)
3101
+
3102
+ torch_latency = []
3103
+ for _ in range(args.total_runs):
3104
+ start = time.time()
3105
+ _ = model.generate(
3106
+ input_ids=input_ids,
3107
+ attention_mask=attention_mask,
3108
+ max_length=args.max_length,
3109
+ min_length=args.min_length,
3110
+ num_beams=args.num_beams,
3111
+ early_stopping=args.early_stopping,
3112
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
3113
+ eos_token_id=eos_token_id,
3114
+ pad_token_id=pad_token_id,
3115
+ num_return_sequences=args.num_return_sequences,
3116
+ length_penalty=args.length_penalty,
3117
+ repetition_penalty=args.repetition_penalty,
3118
+ bad_words_ids=bad_words_ids if bad_words_ids else None,
3119
+ return_dict_in_generate=True,
3120
+ output_scores=args.output_sequences_scores or args.output_token_scores,
3121
+ )
3122
+ torch_latency.append(time.time() - start)
3123
+ batch_size = input_ids.shape[0]
3124
+ from benchmark_helper import get_latency_result # noqa: PLC0415
3125
+
3126
+ return get_latency_result(torch_latency, batch_size)
3127
+
3128
+
3129
+ def create_attention_mask(input_ids, pad_token_id):
3130
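+ """Build an attention mask for left-padded inputs: leading pad tokens get 0, and every
+ position from the first non-pad token onward (including later pad tokens) gets 1."""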
+ attention_mask = np.ones(input_ids.shape, dtype=np.int32)
3131
+ for i in range(input_ids.shape[0]):
3132
+ abs_pos = 0
3133
+ for j in range(input_ids.shape[1]):
3134
+ if input_ids[i][j] == pad_token_id and abs_pos == 0:
3135
+ attention_mask[i][j] = 0
3136
+ else:
3137
+ abs_pos += 1
3138
+ return attention_mask
3139
+
3140
+
3141
+ def test_gpt_model(
3142
+ args: argparse.Namespace,
3143
+ sentences: list[str] | None = None,
3144
+ is_greedy: bool = False,
3145
+ ):
3146
+ """Test GPT-2 model
3147
+
3148
+ Args:
3149
+ args (argparse.Namespace): arguments parsed from command line
3150
+ sentences (Optional[List[str]], optional): input text. Defaults to None.
3151
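+ is_greedy (bool, optional): whether to run greedy search (num_beams=1 and num_return_sequences=1). Defaults to False.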
+
3152
+ Returns:
3153
+ Union[Dict[str, Any], None]: A dictionary mapping each metric name to its value (integer or string), or None.
3154
+ """
3155
+ assert args.model_type == "gpt2"
3156
+
3157
+ tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
3158
+ tokenizer.padding_side = "left"
3159
+ tokenizer.pad_token = tokenizer.eos_token
3160
+
3161
+ model = GPT2LMHeadModel.from_pretrained(
3162
+ args.model_name_or_path,
3163
+ cache_dir=args.cache_dir,
3164
+ pad_token_id=tokenizer.eos_token_id,
3165
+ )
3166
+
3167
+ # Use different length sentences to test batching
3168
+ if sentences is None:
3169
+ sentences = [
3170
+ "The product is released",
3171
+ "I enjoy walking in the park",
3172
+ "Test best way to invest",
3173
+ ]
3174
+
3175
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True)
3176
+ input_ids = inputs["input_ids"]
3177
+ attention_mask = inputs["attention_mask"]
3178
+
3179
+ bad_words = "walk in park"
3180
+ bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True)
3181
+ bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
3182
+ if args.vocab_mask:
3183
+ logger.debug("bad_words_ids", bad_words_ids) # noqa: PLE1205
3184
+ else:
3185
+ bad_words_ids = []
3186
+
3187
+ config = model.config
3188
+ eos_token_id = config.eos_token_id
3189
+ pad_token_id = config.eos_token_id
3190
+ vocab_size = config.vocab_size
3191
+
3192
+ torch_decoded_sequences = []
3193
+ beam_outputs = None
3194
+ if not args.disable_parity:
3195
+ print("-" * 50)
3196
+ print("Test PyTorch model and beam search with huggingface transformers...")
3197
+ beam_outputs = model.generate(
3198
+ input_ids=input_ids,
3199
+ attention_mask=attention_mask,
3200
+ max_length=args.max_length,
3201
+ min_length=args.min_length,
3202
+ num_beams=args.num_beams,
3203
+ early_stopping=args.early_stopping,
3204
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
3205
+ eos_token_id=eos_token_id,
3206
+ pad_token_id=pad_token_id,
3207
+ num_return_sequences=args.num_return_sequences,
3208
+ length_penalty=args.length_penalty,
3209
+ repetition_penalty=args.repetition_penalty,
3210
+ bad_words_ids=bad_words_ids if bad_words_ids else None,
3211
+ return_dict_in_generate=True,
3212
+ output_scores=args.output_sequences_scores or args.output_token_scores,
3213
+ )
3214
+ print("input_ids", input_ids)
3215
+ print("huggingface transformers outputs:")
3216
+ print("sequences", beam_outputs.sequences)
3217
+ if args.output_sequences_scores:
3218
+ print("sequences_scores", beam_outputs.sequences_scores)
3219
+ if args.output_token_scores:
3220
+ print("scores", beam_outputs.scores)
3221
+ for i, sequence in enumerate(beam_outputs.sequences):
3222
+ decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
3223
+ torch_decoded_sequences.append(decoded_sequence)
3224
+ print(f"{i}: {decoded_sequence}")
3225
+
3226
+ print("-" * 50)
3227
+ print("Testing beam search with onnxruntime...")
3228
+
3229
+ if is_greedy:
3230
+ inputs = {
3231
+ "input_ids": input_ids.cpu().numpy().astype(np.int32),
3232
+ "max_length": np.array([args.max_length], dtype=np.int32),
3233
+ "min_length": np.array([args.min_length], dtype=np.int32),
3234
+ "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
3235
+ }
3236
+ else:
3237
+ inputs = {
3238
+ "input_ids": input_ids.cpu().numpy().astype(np.int32),
3239
+ "max_length": np.array([args.max_length], dtype=np.int32),
3240
+ "min_length": np.array([args.min_length], dtype=np.int32),
3241
+ "num_beams": np.array([args.num_beams], dtype=np.int32),
3242
+ "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
3243
+ "length_penalty": np.array([args.length_penalty], dtype=np.float32),
3244
+ "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
3245
+ }
3246
+
3247
+ if args.vocab_mask:
3248
+ vocab_mask = np.ones((vocab_size), dtype=np.int32)
3249
+ if args.vocab_mask:
3250
+ for bad_word_id in bad_words_ids:
3251
+ vocab_mask[bad_word_id] = 0
3252
+ inputs["vocab_mask"] = vocab_mask
3253
+
3254
+ if args.custom_attention_mask:
3255
+ inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id)
3256
+
3257
+ batch_size = input_ids.shape[0]
3258
+ if args.prefix_vocab_mask:
3259
+ logger.info("Use prefix vocab mask with all ones in ORT, but no corresponding setting for Torch model.")
3260
+ prefix_vocab_mask = np.ones((batch_size, vocab_size), dtype=np.int32)
3261
+ inputs["prefix_vocab_mask"] = prefix_vocab_mask
3262
+
3263
+ if args.save_test_data:
3264
+ test_data_dir = Path(args.output).parent.as_posix()
3265
+ logger.debug("test_data_dir", test_data_dir) # noqa: PLE1205
3266
+ from bert_test_data import output_test_data # noqa: PLC0415
3267
+
3268
+ logger.info(f"Saving test_data to {test_data_dir}/test_data_set_* ...")
3269
+
3270
+ all_inputs = [inputs]
3271
+ for i, inputs in enumerate(all_inputs):
3272
+ data_dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
3272
+ output_test_data(data_dir, inputs)
3274
+
3275
+ logger.debug("ORT inputs", inputs) # noqa: PLE1205
3276
+
3277
+ if args.disable_perf_test:
3278
+ return
3279
+
3280
+ logger.debug("Creating ort session......")
3281
+ ort_session = create_ort_session(args.output, args.use_gpu, args.use_sln_strict_mode)
3282
+
3283
+ logger.debug("Run ort session......")
3284
+ result = ort_session.run(None, inputs)
3285
+
3286
+ # Test performance
3287
+ latency = []
3288
+ for _ in range(args.total_runs):
3289
+ start = time.time()
3290
+ _ = ort_session.run(None, inputs)
3291
+ latency.append(time.time() - start)
3292
+
3293
+ from benchmark_helper import get_latency_result # noqa: PLC0415
3294
+
3295
+ batch_size = input_ids.shape[0]
3296
+ output = get_latency_result(latency, batch_size)
3297
+
3298
+ print("ORT outputs:")
3299
+ sequences = result[0]
3300
+ print("sequences", sequences)
3301
+ if args.output_sequences_scores:
3302
+ print("sequences_scores", result[1])
3303
+ if args.output_token_scores:
3304
+ print("scores", result[2])
3305
+
3306
+ if is_greedy:
3307
+ (batch_size, max_length) = sequences.shape
3308
+ ort_decoded_sequences = []
3309
+ for i in range(batch_size):
3310
+ decoded_sequence = tokenizer.decode(sequences[i], skip_special_tokens=True)
3311
+ ort_decoded_sequences.append(decoded_sequence)
3312
+ print(f"batch {i} sequence: {decoded_sequence}")
3313
+ else:
3314
+ (batch_size, num_sequences, max_length) = sequences.shape
3315
+ ort_decoded_sequences = []
3316
+ for i in range(batch_size):
3317
+ for j in range(num_sequences):
3318
+ decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
3319
+ ort_decoded_sequences.append(decoded_sequence)
3320
+ print(f"batch {i} sequence {j}: {decoded_sequence}")
3321
+
3322
+ if beam_outputs:
3323
+ torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
3324
+ ort_sequences = torch.LongTensor(sequences)
3325
+ print("-" * 50)
3326
+ print("Torch Sequences:")
3327
+ print(torch_sequences)
3328
+ print(torch_decoded_sequences)
3329
+ print("-" * 50)
3330
+ print("ORT Sequences:")
3331
+ print(ort_sequences)
3332
+ print(ort_decoded_sequences)
3333
+ print("-" * 50)
3334
+ # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not.
3335
+ is_same = torch_decoded_sequences == ort_decoded_sequences
3336
+ print("Torch and ORT result is", "same" if is_same else "different")
3337
+ output["parity"] = is_same
3338
+
3339
+ if args.torch_performance:
3340
+ torch_latency_output = test_torch_performance(
3341
+ args,
3342
+ model,
3343
+ input_ids,
3344
+ attention_mask,
3345
+ eos_token_id,
3346
+ pad_token_id,
3347
+ bad_words_ids,
3348
+ )
3349
+ print("Torch Latency", torch_latency_output)
3350
+
3351
+ print("ORT", output)
3352
+
3353
+ return output
3354
+
3355
+
3356
+ def test_t5_model(args: argparse.Namespace, sentences: list[str] | None = None):
3357
+ """Test T5 or MT5 model
3358
+
3359
+ Args:
3360
+ args (argparse.Namespace): arguments parsed from command line
3361
+ sentences (Optional[List[str]], optional): input text. Defaults to None.
3362
+
3363
+ Returns:
3364
+ Union[Dict[str, Any], None]: A dictionary mapping each metric name to its value (integer or string), or None.
3365
+ """
3366
+ assert args.model_type in ["t5", "mt5"]
3367
+
3368
+ if args.prefix_vocab_mask:
3369
+ logger.debug("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
3370
+ return None
3371
+
3372
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
3373
+ tokenizer.padding_side = "left"
3374
+
3375
+ if args.model_type == "t5":
3376
+ model = T5ForConditionalGeneration.from_pretrained(
3377
+ args.model_name_or_path,
3378
+ cache_dir=args.cache_dir,
3379
+ )
3380
+ else:
3381
+ model = MT5ForConditionalGeneration.from_pretrained(
3382
+ args.model_name_or_path,
3383
+ cache_dir=args.cache_dir,
3384
+ )
3385
+
3386
+ # Use different length sentences to test batching
3387
+ if sentences is None:
3388
+ sentences = [
3389
+ "translate English to French: The product is released",
3390
+ "summarize: research continues to show that pets bring real health benefits to their owners. Having a dog around can lead to lower levels of stress for both adults and kids.",
3391
+ # "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. "
3392
+ # + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.",
3393
+ ]
3394
+
3395
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True)
3396
+ input_ids = inputs["input_ids"]
3397
+ attention_mask = inputs["attention_mask"]
3398
+
3399
+ bad_words = "walk in park"
3400
+ bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS)
3401
+ bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
3402
+ if args.vocab_mask:
3403
+ logger.debug("bad_words_ids", bad_words_ids) # noqa: PLE1205
3404
+ else:
3405
+ bad_words_ids = []
3406
+
3407
+ config = model.config
3408
+ eos_token_id = config.eos_token_id
3409
+ pad_token_id = config.pad_token_id
3410
+ vocab_size = config.vocab_size
3411
+ logger.debug(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}")
3412
+
3413
+ torch_decoded_sequences = []
3414
+ if not args.disable_parity:
3415
+ print("-" * 50)
3416
+ print("Test PyTorch model and beam search with huggingface transformers...")
3417
+ beam_outputs = model.generate(
3418
+ input_ids=input_ids,
3419
+ attention_mask=attention_mask,
3420
+ max_length=args.max_length,
3421
+ min_length=args.min_length,
3422
+ num_beams=args.num_beams,
3423
+ early_stopping=args.early_stopping,
3424
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
3425
+ eos_token_id=eos_token_id,
3426
+ pad_token_id=pad_token_id,
3427
+ num_return_sequences=args.num_return_sequences,
3428
+ length_penalty=args.length_penalty,
3429
+ repetition_penalty=args.repetition_penalty,
3430
+ bad_words_ids=bad_words_ids if bad_words_ids else None,
3431
+ return_dict_in_generate=True,
3432
+ output_scores=args.output_sequences_scores or args.output_token_scores,
3433
+ )
3434
+
3435
+ print("input_ids", input_ids)
3436
+ print("huggingface transformers outputs:")
3437
+ print("sequences", beam_outputs.sequences)
3438
+ if args.output_sequences_scores:
3439
+ print("sequences_scores", beam_outputs.sequences_scores)
3440
+ if args.output_token_scores:
3441
+ print("scores", beam_outputs.scores)
3442
+ for i, sequence in enumerate(beam_outputs.sequences):
3443
+ decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
3444
+ torch_decoded_sequences.append(decoded_sequence)
3445
+ print(f"{i}: {decoded_sequence}")
3446
+
3447
+ print("-" * 50)
3448
+ print("Testing beam search with onnxruntime...")
3449
+
3450
+ vocab_mask = np.ones((vocab_size), dtype=np.int32)
3451
+ if args.vocab_mask:
3452
+ for bad_word_id in bad_words_ids:
3453
+ vocab_mask[bad_word_id] = 0
3454
+
3455
+ inputs = {
3456
+ "input_ids": input_ids.cpu().numpy().astype(np.int32),
3457
+ "max_length": np.array([args.max_length], dtype=np.int32),
3458
+ "min_length": np.array([args.min_length], dtype=np.int32),
3459
+ "num_beams": np.array([args.num_beams], dtype=np.int32),
3460
+ "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
3461
+ "length_penalty": np.array([args.length_penalty], dtype=np.float32),
3462
+ "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
3463
+ }
3464
+
3465
+ if args.vocab_mask:
3466
+ inputs["vocab_mask"] = vocab_mask
3467
+
3468
+ if args.custom_attention_mask:
3469
+ inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id)
3470
+
3471
+ if args.save_test_data:
3472
+ test_data_dir = Path(args.output).parent.as_posix()
3473
+ logger.debug("test_data_dir", test_data_dir) # noqa: PLE1205
3474
+ from bert_test_data import output_test_data # noqa: PLC0415
3475
+
3476
+ all_inputs = [inputs]
3477
+ for i, inputs in enumerate(all_inputs):
3478
+ data_dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
3479
+ output_test_data(data_dir, inputs)
3480
+
3481
+ logger.debug("ORT inputs", inputs) # noqa: PLE1205
3482
+
3483
+ ort_session = create_ort_session(args.output, args.use_gpu, args.use_sln_strict_mode)
3484
+
3485
+ # Test performance
3486
+ latency = []
3487
+ for _ in range(args.total_runs):
3488
+ start = time.time()
3489
+ result = ort_session.run(None, inputs)
3490
+ latency.append(time.time() - start)
3491
+ batch_size = input_ids.shape[0]
3492
+ from benchmark_helper import get_latency_result # noqa: PLC0415
3493
+
3494
+ output = get_latency_result(latency, batch_size)
3495
+
3496
+ print("ORT outputs:")
3497
+ sequences = result[0]
3498
+ print("sequences", sequences)
3499
+ if args.output_sequences_scores:
3500
+ print("sequences_scores", result[1])
3501
+ if args.output_token_scores:
3502
+ print("scores", result[2])
3503
+
3504
+ (batch_size, num_sequences, max_length) = sequences.shape
3505
+ ort_decoded_sequences = []
3506
+ for i in range(batch_size):
3507
+ for j in range(num_sequences):
3508
+ decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
3509
+ ort_decoded_sequences.append(decoded_sequence)
3510
+ print(f"batch {i} sequence {j}: {decoded_sequence}")
3511
+
3512
+ if not args.disable_parity:
3513
+ torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
3514
+ ort_sequences = torch.LongTensor(sequences)
3515
+ print("-" * 50)
3516
+ print("Torch Sequences:")
3517
+ print(torch_sequences)
3518
+ print(torch_decoded_sequences)
3519
+ print("-" * 50)
3520
+ print("ORT Sequences:")
3521
+ print(ort_sequences)
3522
+ print(ort_decoded_sequences)
3523
+ print("-" * 50)
3524
+ # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not.
3525
+ is_same = torch_decoded_sequences == ort_decoded_sequences
3526
+ print("Torch and ORT result is", "same" if is_same else "different")
3527
+ output["parity"] = is_same
3528
+
3529
+ if args.torch_performance:
3530
+ torch_latency_output = test_torch_performance(
3531
+ args,
3532
+ model,
3533
+ input_ids,
3534
+ attention_mask,
3535
+ eos_token_id,
3536
+ pad_token_id,
3537
+ bad_words_ids,
3538
+ )
3539
+ print("Torch Latency", torch_latency_output)
3540
+
3541
+ print("ORT", output)
3542
+ return output
3543
+
3544
+
3545
+ def main(argv: list[str] | None = None, sentences: list[str] | None = None):
3546
+ """Main entry function
3547
+
3548
+ Args:
3549
+ argv (Optional[List[str]], optional): command line arguments to parse. Defaults to None.
3550
+ sentences (Optional[List[str]], optional): input text. Defaults to None.
3551
+
3552
+ Raises:
3553
+ ValueError: Path does not exist: --encoder_decoder_init_onnx
3554
+ ValueError: Path does not exist: --decoder_onnx
3555
+ ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5
3556
+
3557
+ Returns:
3558
+ Union[Dict[str, Any], None]: A dictionary mapping each metric name to its value (integer or string), or None.
3559
+ """
3560
+
3561
+ args = parse_arguments(argv)
3562
+ setup_logger(args.verbose)
3563
+
3564
+ if args.model_type in ["t5", "mt5"]:
3565
+ if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx):
3566
+ raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}")
3567
+ if args.decoder_onnx and not os.path.exists(args.decoder_onnx):
3568
+ raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}")
3569
+ if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or (
3570
+ args.decoder_onnx and not args.encoder_decoder_init_onnx
3571
+ ):
3572
+ raise ValueError("--decoder_onnx shall be used together with --encoder_decoder_init_onnx")
3573
+
3574
+ is_greedy = args.num_beams == 1 and args.num_return_sequences == 1
3575
+
3576
+ if args.model_type == "gpt2" and is_greedy:
3577
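+ # A top_p strictly between 0 and 1 selects the Sampling operator; otherwise plain GreedySearch is used.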
+ if args.top_p > 0.0 and args.top_p < 1.0:
3578
+ convert_generation_model(args, GenerationType.SAMPLING)
3579
+ logger.info(
3580
+ "The test for the gpt2_sampling onnx model is limited to a non-custom model with a small top_p (e.g. <= 0.01) value. The result should be the same as gpt2 greedy search."
3581
+ )
3582
+ if args.top_p > 0.01 or args.custom or args.seed:
3583
+ return
3584
+ else:
3585
+ convert_generation_model(args, GenerationType.GREEDYSEARCH)
3586
+ else:
3587
+ convert_generation_model(args)
3588
+
3589
+ logger.info("start testing model...")
3590
+ if args.model_type in ["t5", "mt5"]:
3591
+ result = test_t5_model(args, sentences=sentences)
3592
+ else:
3593
+ result = test_gpt_model(args, sentences=sentences, is_greedy=is_greedy)
3594
+
3595
+ if result:
3596
+ if args.use_external_data_format:
3597
+ logger.info(f"Output files: {args.output}, {args.output}.data")
3598
+ else:
3599
+ logger.info(f"Output file: {args.output}")
3600
+
3601
+ return result
3602
+
3603
+
3604
+ if __name__ == "__main__":
3605
+ main()