onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
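The onnxruntime/capi files listed above (onnxruntime_inference_collection.py, onnxruntime.dll, DirectML.dll) make up the runtime shipped in this wheel. As a minimal sketch of how this DirectML build is typically driven, assuming a placeholder model file model.onnx and an illustrative input shape (neither comes from this package):

    import numpy as np
    import onnxruntime as ort

    # This build registers DmlExecutionProvider; CPUExecutionProvider is kept as a fallback.
    session = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider", "CPUExecutionProvider"])

    # Query the first graph input and feed a zero tensor (the shape here is only for illustration).
    first_input = session.get_inputs()[0]
    outputs = session.run(None, {first_input.name: np.zeros((1, 3, 224, 224), dtype=np.float32)})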
onnxruntime/transformers/convert_generation.py
@@ -0,0 +1,3124 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # -------------------------------------------------------------------------
+ """
+ This converts GPT2 or T5 model to onnx with beam search operator.
+
+ Example 1: convert gpt2 model with beam search:
+     python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx
+
+ Example 2: convert gpt2 model with beam search containing specific cuda optimizations:
+     python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu \
+         --past_present_share_buffer --use_decoder_masked_attention
+
+ Example 3: convert gpt2 model with beam search with mixed precision and enable SkipLayerNorm strict mode:
+     python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu -p fp16 --use_sln_strict_mode
+
+ Example 4: convert T5 model with beam search in two steps:
+     cd ./models/t5
+     python convert_to_onnx.py -m t5-small
+     cd ../..
+     python convert_generation.py -m t5-small --model_type t5 \
+         --decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \
+         --encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \
+         --output ./models/t5/onnx_models/t5_small_beam_search.onnx
+
+ Example 5: convert T5 model with beam search. All in one step:
+     python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx
+
+ Example 6: convert T5 model with beam search containing specific cuda optimizations. All in one step:
+     python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx \
+         --use_gpu --past_present_share_buffer --use_decoder_masked_attention
+
+ Example 7: convert MT5 model with external data file like mt5-base-beamsearch.onnx.data in below example.
+     python convert_generation.py -m google/mt5-base --model_type mt5 --output mt5-base-beamsearch.onnx -e
+
+ Example 8: convert gpt2 model with greedy search:
+     python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1
+
+ Example 9: convert gpt2 model with sampling:
+     python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6
+ """
+
+ import argparse
+ import logging
+ import math
+ import os
+ import time
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ import onnx
+ import torch
+ from benchmark_helper import Precision, setup_logger
+ from fusion_utils import NumpyHelper
+ from onnx import GraphProto, ModelProto, TensorProto
+ from onnx_model import OnnxModel
+ from transformers import (
+     GPT2Config,
+     GPT2LMHeadModel,
+     GPT2Tokenizer,
+     MT5Config,
+     MT5ForConditionalGeneration,
+     T5Config,
+     T5ForConditionalGeneration,
+     T5Tokenizer,
+ )
+
+ from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers
+ from onnxruntime.transformers.models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx
+ from onnxruntime.transformers.models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS
+ from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
+ from onnxruntime.transformers.models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS
+
+ logger = logging.getLogger("")
+
+
+ class GenerationType(Enum):
+     BEAMSEARCH = "beam_search"
+     GREEDYSEARCH = "greedy_search"
+     SAMPLING = "sampling"
+
+     def __str__(self):
+         return self.value
+
+
+ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
+     """Parse arguments
+
+     Args:
+         argv (Optional[List[str]], optional): _description_. Defaults to None.
+
+     Returns:
+         argparse.Namespace: Parsed arguments.
+     """
+     parser = argparse.ArgumentParser()
+
+     input_group = parser.add_argument_group("Input options")
+
+     input_group.add_argument(
+         "-m",
+         "--model_name_or_path",
+         required=True,
+         type=str,
+         help="Pytorch model checkpoint path, or pretrained model name in the list: "
+         + ", ".join(PRETRAINED_GPT2_MODELS + PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS),
+     )
+
+     input_group.add_argument(
+         "--model_type",
+         required=False,
+         type=str,
+         default="gpt2",
+         choices=["gpt2", "t5", "mt5"],
+         help="Model type (default is gpt2) in the list: " + ", ".join(["gpt2", "t5", "mt5"]),
+     )
+
+     input_group.add_argument(
+         "--cache_dir",
+         required=False,
+         type=str,
+         default=os.path.join(".", "cache_models"),
+         help="Directory to cache pre-trained models",
+     )
+
+     input_group.add_argument(
+         "--decoder_onnx",
+         required=False,
+         type=str,
+         default="",
+         help="Path of onnx model for decoder. Specify it when you have exported the model.",
+     )
+
+     input_group.add_argument(
+         "--encoder_decoder_init_onnx",
+         required=False,
+         type=str,
+         default="",
+         help="Path of ONNX model for encoder and decoder initialization. Specify it when you have exported the model.",
+     )
+
+     parser.add_argument(
+         "--verbose",
+         required=False,
+         action="store_true",
+         help="Print more information",
+     )
+     parser.set_defaults(verbose=False)
+
+     output_group = parser.add_argument_group("Output options")
+
+     output_group.add_argument(
+         "--output",
+         required=True,
+         type=str,
+         help="Output path for onnx model with beam search.",
+     )
+
+     output_group.add_argument(
+         "-p",
+         "--precision",
+         required=False,
+         type=Precision,
+         default=Precision.FLOAT32,
+         choices=[Precision.FLOAT32, Precision.FLOAT16],
+         help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision",
+     )
+
+     output_group.add_argument(
+         "-b",
+         "--op_block_list",
+         required=False,
+         nargs="*",
+         default=["auto"],
+         help="Disable certain onnx operators when exporting model to onnx format. When using default"
+         'value for gpt2 type of model fp16 precision, it will be set to ["Add", "LayerNormalization",'
+         ' "SkipLayerNormalization", "FastGelu"]. Other situation, it will be set to []',
+     )
+
+     output_group.add_argument(
+         "-e",
+         "--use_external_data_format",
+         required=False,
+         action="store_true",
+         help="save external data for model > 2G",
+     )
+     output_group.set_defaults(use_external_data_format=False)
+
+     output_group.add_argument(
+         "-s", "--run_shape_inference", required=False, action="store_true", help="run shape inference"
+     )
+     output_group.set_defaults(run_shape_inference=False)
+
+     output_group.add_argument(
+         "-dpvs",
+         "--disable_pad_vocab_size",
+         required=False,
+         action="store_true",
+         help="Do not pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is"
+         " the vocab size. The logits MatMul may hence be of poor performance for fp16 precision.",
+     )
+     output_group.set_defaults(disable_pad_vocab_size=False)
+
+     output_group.add_argument(
+         "-dsgd",
+         "--disable_separate_gpt2_decoder_for_init_run",
+         required=False,
+         action="store_true",
+         help="Do not create separate decoder subgraphs for initial and remaining runs. This does not allow "
+         "for optimizations based on sequence lengths in each subgraph",
+     )
+     output_group.set_defaults(disable_separate_gpt2_decoder_for_init_run=False)
+
+     output_group.add_argument(
+         "-i",
+         "--disable_shared_initializers",
+         required=False,
+         action="store_true",
+         help="do not share initializers in encoder and decoder for T5 or in the init decoder and decoder for "
+         "GPT2. It will increase memory usage of t5/mt5/gpt2 models.",
+     )
+     output_group.set_defaults(disable_shared_initializers=False)
+
+     model_group = parser.add_argument_group("Beam search parameters that stored in the output model")
+
+     model_group.add_argument(
+         "--output_sequences_scores",
+         required=False,
+         action="store_true",
+         help="output sequences scores",
+     )
+     model_group.set_defaults(output_sequences_scores=False)
+
+     model_group.add_argument(
+         "--output_token_scores",
+         required=False,
+         action="store_true",
+         help="output token scores",
+     )
+     model_group.set_defaults(output_token_scores=False)
+
+     model_group.add_argument("--early_stopping", required=False, action="store_true")
+     model_group.set_defaults(early_stopping=False)
+
+     model_group.add_argument(
+         "--no_repeat_ngram_size",
+         type=int,
+         required=False,
+         default=0,
+         help="No repeat ngram size",
+     )
+
+     model_group.add_argument(
+         "--vocab_mask",
+         required=False,
+         action="store_true",
+         help="Enable vocab_mask. This mask applies only to every generated token to filter some bad words.",
+     )
+     model_group.set_defaults(vocab_mask=False)
+
+     model_group.add_argument(
+         "--past_present_share_buffer",
+         required=False,
+         action="store_true",
+         help="Use shared buffer for past and present, currently work for gpt2 greedy/sampling search.",
+     )
+     model_group.set_defaults(past_present_share_buffer=False)
+
+     model_group.add_argument(
+         "--use_decoder_masked_attention",
+         required=False,
+         action="store_true",
+         help="Uses `DecoderMaskedSelfAttention` or `DecoderMaskedMultiHeadAttention` to optimize the decoding Attention computation. "
+         "Must be used with `past_present_share_buffer`. Currently, only Attention head sizes of 32, 64 and 128 are supported.",
+     )
+     model_group.set_defaults(use_decoder_masked_attention=False)
+
+     model_group.add_argument(
+         "--prefix_vocab_mask",
+         required=False,
+         action="store_true",
+         help="Enable prefix_vocab_mask. This mask can be used to filter bad words in the first generated token only",
+     )
+     model_group.set_defaults(prefix_vocab_mask=False)
+
+     model_group.add_argument(
+         "--custom_attention_mask",
+         required=False,
+         action="store_true",
+         help="Enable custom_attention_mask. This mask can be used to replace default encoder attention mask",
+     )
+     model_group.set_defaults(custom_attention_mask=False)
+
+     model_group.add_argument(
+         "--presence_mask",
+         required=False,
+         action="store_true",
+         help="Presence mask for custom sampling",
+     )
+     model_group.set_defaults(presence_mask=False)
+
+     model_group.add_argument(
+         "--seed",
+         required=False,
+         action="store_true",
+         help="Random seed for sampling op",
+     )
+     model_group.set_defaults(seed=False)
+
+     beam_parameters_group = parser.add_argument_group(
+         "Beam search parameters not stored in the output model, for testing parity and performance"
+     )
+
+     beam_parameters_group.add_argument("--min_length", type=int, required=False, default=1, help="Min sequence length")
+
+     beam_parameters_group.add_argument("--max_length", type=int, required=False, default=50, help="Max sequence length")
+
+     beam_parameters_group.add_argument("--num_beams", type=int, required=False, default=4, help="Beam size")
+
+     beam_parameters_group.add_argument(
+         "--num_return_sequences",
+         type=int,
+         required=False,
+         default=1,
+         help="Number of return sequence <= num_beams",
+     )
+
+     beam_parameters_group.add_argument(
+         "--length_penalty",
+         type=float,
+         required=False,
+         default=1,
+         help="Positive. >1 to penalize and <1 to encourage short sentence.",
+     )
+
+     beam_parameters_group.add_argument(
+         "--repetition_penalty",
+         type=float,
+         required=False,
+         default=1,
+         help="Positive. >1 to penalize and <1 to encourage.",
+     )
+
+     beam_parameters_group.add_argument(
+         "--temperature",
+         type=float,
+         required=False,
+         default=1.0,
+         help="The value used to module the next token probabilities.",
+     )
+
+     beam_parameters_group.add_argument(
+         "--top_p",
+         type=float,
+         required=False,
+         default=1.0,
+         help="Top P for sampling",
+     )
+
+     beam_parameters_group.add_argument(
+         "--filter_value",
+         type=float,
+         required=False,
+         default=-float("Inf"),
+         help="Filter value for Top P sampling",
+     )
+
+     beam_parameters_group.add_argument(
+         "--min_tokens_to_keep",
+         type=int,
+         required=False,
+         default=1,
+         help="Minimum number of tokens we keep per batch example in the output.",
+     )
+
+     beam_parameters_group.add_argument(
+         "--presence_penalty",
+         type=float,
+         required=False,
+         default=0.0,
+         help="presence penalty for custom sampling.",
+     )
+
+     beam_parameters_group.add_argument(
+         "--custom",
+         type=int,
+         required=False,
+         default=0,
+         help="If 1 customized top P logic is applied",
+     )
+
+     beam_parameters_group.add_argument(
+         "--vocab_size",
+         type=int,
+         required=False,
+         default=-1,
+         help="Vocab_size of the underlying model used to decide the shape of vocab mask",
+     )
+
+     beam_parameters_group.add_argument(
+         "--eos_token_id",
+         type=int,
+         required=False,
+         default=-1,
+         help="custom eos_token_id for generating model with existing onnx encoder/decoder",
+     )
+
+     beam_parameters_group.add_argument(
+         "--pad_token_id",
+         type=int,
+         required=False,
+         default=-1,
+         help="custom pad_token_id for generating model with existing onnx encoder/decoder",
+     )
+
+     test_group = parser.add_argument_group("Other options for testing parity and performance")
+
+     test_group.add_argument(
+         "--use_sln_strict_mode",
+         required=False,
+         action="store_true",
+         help="Enable strict mode for SLN in CUDA provider. This ensures a better accuracy but will be slower.",
+     )
+     test_group.set_defaults(use_sln_strict_mode=False)
+
+     test_group.add_argument(
+         "--use_gpu", required=False, action="store_true", help="use GPU for inference. Required for fp16."
+     )
+     test_group.set_defaults(use_gpu=False)
+
+     test_group.add_argument(
+         "--disable_parity",
+         required=False,
+         action="store_true",
+         help="do not run parity test",
+     )
+     test_group.set_defaults(disable_parity=False)
+
+     test_group.add_argument(
+         "--disable_perf_test",
+         required=False,
+         action="store_true",
+         help="do not run perf test",
+     )
+     test_group.set_defaults(disable_perf_test=False)
+
+     test_group.add_argument(
+         "--torch_performance",
+         required=False,
+         action="store_true",
+         help="test PyTorch performance",
+     )
+     test_group.set_defaults(torch_performance=False)
+
+     test_group.add_argument(
+         "--total_runs",
+         required=False,
+         type=int,
+         default=1,
+         help="Number of times of inference for latency measurement",
+     )
+
+     test_group.add_argument(
+         "--save_test_data",
+         required=False,
+         action="store_true",
+         help="save test data for onnxruntime_perf_test tool",
+     )
+     test_group.set_defaults(save_test_data=False)
+
+     args = parser.parse_args(argv)
+
+     return args
+
+
+ def gpt2_to_onnx(args: argparse.Namespace):
+     """Convert GPT-2 model to onnx
+
+     Args:
+         args (argparse.Namespace): arguments parsed from command line
+     """
+     model_name = args.model_name_or_path
+
+     arguments = [
+         "--model_name_or_path",
+         model_name,
+         "--output",
+         args.decoder_onnx,
+         "--optimize_onnx",
+         "--precision",
+         "fp32" if args.precision == Precision.FLOAT32 else "fp16",
+         "--test_runs",
+         "1",
+         "--test_cases",
+         "10",
+         "--overwrite",  # Overwrite onnx file if existed
+     ]
+     if args.cache_dir:
+         arguments.extend(["--cache_dir", args.cache_dir])
+     if args.use_gpu:
+         arguments.append("--use_gpu")
+     if args.use_external_data_format:
+         arguments.append("--use_external_data_format")
+
+     if len(args.op_block_list):
+         arguments.extend(["--op_block_list"])
+         arguments.extend(args.op_block_list)
+
+     if args.precision == Precision.FLOAT16:
+         assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
+         # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
+         # Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
+         # Currently logits and past state shall be same data type.
+
+     if args.verbose:
+         logger.info(f"arguments for convert_to_onnx:{arguments}")
+
+     convert_gpt2_to_onnx(argv=arguments)
+
+
+ def t5_to_onnx(args: argparse.Namespace):
+     """Convert T5 model to onnx
+
+     Args:
+         args (argparse.Namespace): arguments parsed from command line
+     """
+     paths = export_t5_onnx_models(
+         args.model_name_or_path,
+         args.cache_dir,
+         Path(args.output).parent,
+         use_gpu=args.use_gpu,
+         use_external_data_format=args.use_external_data_format,
+         optimize_onnx=(args.precision != Precision.FLOAT16),
+         precision=args.precision,
+         verbose=False,
+         use_decoder_start_token=False,
+         merge_encoder_and_decoder_init=True,
+         overwrite=True,
+         disable_auto_mixed_precision=False,
+         use_int32_inputs=True,
+         model_type=args.model_type,
+     )
+
+     logger.debug(f"onnx model for encoder: {paths[0]}")
+     logger.debug(f"onnx model for decoder: {paths[1]}")
+     args.encoder_decoder_init_onnx = paths[0]
+     args.decoder_onnx = paths[1]
+
+
+ def shape_inference(onnx_path: str, use_external_data_format: bool = True):
+     """Shape inference on an onnx file, which will be overwritten.
+
+     Args:
+         onnx_path (str): Path of onnx model
+         use_external_data_format(bool): output tensors to external data or not.
+     """
+     # Run symbolic shape inference to walk around ORT shape inference issue for subgraph.
+     from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
+
+     model = onnx.load_model(onnx_path, load_external_data=True)
+     out = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=False)
+     if out:
+         OnnxModel.save(out, onnx_path, save_as_external_data=use_external_data_format)
+     else:
+         logger.warning("Failed to run symbolic shape inference on the model.")
+
+
+ def pad_weights_of_logits_matmul(onnx_path: str, use_external_data_format: bool = True) -> bool:
+     """Pad the logits MatMul weight in the provided decoder model, which will be overwritten.
+
+     Args:
+         onnx_path (str): Path of onnx model
+         use_external_data_format(bool): output tensors to external data or not.
+     """
+     decoder_model_proto = onnx.load_model(onnx_path, load_external_data=True)
+
+     logits_output_name = decoder_model_proto.graph.output[0].name
+
+     decoder_model = OnnxModel(decoder_model_proto)
+
+     output_name_to_node = decoder_model.output_name_to_node()
+     assert logits_output_name in output_name_to_node
+
+     matmul_node = output_name_to_node[logits_output_name]
+     # Sanity check - the logits need to be produced by a MatMul node
+     if matmul_node.op_type != "MatMul":
+         return False
+
+     # The logits MatMul weight MUST be an initializer (or)
+     # it MUST be flowing through a Transpose whose input is
+     # an initializer
+     pad_along_axis_1 = True
+     logits_weight = decoder_model.get_initializer(matmul_node.input[1])
+     if logits_weight is None:
+         transpose_before_matmul = decoder_model.match_parent(matmul_node, "Transpose", 1)
+
+         if transpose_before_matmul is None:
+             return False
+
+         logits_weight = decoder_model.get_initializer(transpose_before_matmul.input[0])
+
+         if logits_weight is None:
+             return False
+
+         pad_along_axis_1 = False
+
+     # The logits MatMul weight MUST be fp16
+     if logits_weight.data_type != TensorProto.DataType.FLOAT16:
+         return False
+
+     # The logits MatMul weight MUST be 2-dimensional
+     if len(logits_weight.dims) != 2:
+         return False
+
+     # Pad and over-write the initializer (if needed)
+     actual_vocab_size = logits_weight.dims[1]
+
+     if (actual_vocab_size % 8) == 0:
+         # Already "padded"
+         return True
+
+     padded_vocab_size = math.ceil(actual_vocab_size / 8) * 8
+     padding = padded_vocab_size - actual_vocab_size
+
+     # TODO(hasesh): Handle cases where the fp16 data is stored in the
+     # non-raw data field
+     if logits_weight.raw_data:
+         if pad_along_axis_1:
+             padding_data = np.zeros((logits_weight.dims[0], padding), dtype=np.float16)
+             weight_with_padding = np.concatenate((NumpyHelper.to_array(logits_weight), padding_data), axis=1)
+             logits_weight.dims[1] = padded_vocab_size
+         else:
+             padding_data = np.zeros((padding, logits_weight.dims[1]), dtype=np.float16)
+             weight_with_padding = np.concatenate((NumpyHelper.to_array(logits_weight), padding_data), axis=0)
+             logits_weight.dims[0] = padded_vocab_size
+
+         logits_weight.raw_data = weight_with_padding.tobytes()
+     else:
+         return False
+
+     # Save the model
+     OnnxModel.save(decoder_model_proto, onnx_path, save_as_external_data=use_external_data_format)
+     return True
+
+
+ def create_ort_session(model_path: str, use_gpu: bool, use_sln_strict_mode: bool) -> InferenceSession:
+     """Create OnnxRuntime session.
+
+     Args:
+         model_path (str): onnx model path
+         use_gpu (bool): use GPU or not
+         use_sln_strict_mode (bool): use strict mode for skip layer normalization or not
+
+     Raises:
+         RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified.
+
+     Returns:
+         onnxruntime.InferenceSession: The created session.
+     """
+     sess_options = SessionOptions()
+     sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
+     execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
+     if use_gpu:
+         if "CUDAExecutionProvider" not in get_available_providers():
+             raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!")
+         else:
+             logger.info("use CUDAExecutionProvider")
+         if use_sln_strict_mode:
+             cuda_provider_options = {"enable_skip_layer_norm_strict_mode": True}
+             provider_options = {"CUDAExecutionProvider": cuda_provider_options}
+             execution_providers = [
+                 (name, provider_options[name]) if name in provider_options else name for name in execution_providers
+             ]
+
+     ort_session = InferenceSession(model_path, sess_options, providers=execution_providers)
+     return ort_session
+
+
+ def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision):
+     """Verify GPT-2 subgraph
+
+     Args:
+         graph (onnx.GraphProto): onnx graph of GPT-2
+         precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
+
+     Raises:
+         ValueError: Number of inputs not expected.
+         ValueError: Input name is not expected.
+         ValueError: Input data type is not expected.
+         ValueError: Number of outputs not expected.
+         ValueError: Output name is not expected.
+         ValueError: Output data type is not expected.
+     """
+     is_float16 = precision == Precision.FLOAT16
+
+     input_count = len(graph.input)
+     layer_count = input_count - 3
+     assert layer_count >= 1
+
+     expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
+     if len(graph.input) != len(expected_inputs):
+         raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
+
+     for i, expected_input in enumerate(expected_inputs):
+         if graph.input[i].name != expected_input:
+             raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
+
+         expected_type = TensorProto.INT32
+         if i >= 3:
+             expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
+
+         input_type = graph.input[i].type.tensor_type.elem_type
+         if input_type != expected_type:
+             raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
+     logger.info("Verifying GPT-2 graph inputs: name and data type are good.")
+
+     expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)]
+     if len(graph.output) != len(expected_outputs):
+         raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
+
+     for i, expected_output in enumerate(expected_outputs):
+         if graph.output[i].name != expected_output:
+             raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
+
+         expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
+         output_type = graph.output[i].type.tensor_type.elem_type
+         if output_type != expected_type:
+             raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {output_type}")
+     logger.info("Verifying GPT-2 graph outputs: name and data type are good.")
+
+     # TODO(tianleiwu): verify shapes of inputs and outputs.
+     return
+
+
+ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
+     """Verify T5 decoder subgraph
+
+     Args:
+         graph (onnx.GraphProto): onnx graph of T5 decoder
+         precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
+
+     Raises:
+         ValueError: Number of inputs not expected.
+         ValueError: Input name is not expected.
+         ValueError: Input data type is not expected.
+         ValueError: Number of outputs not expected.
+         ValueError: Output name is not expected.
+         ValueError: Output data type is not expected.
+     """
+     is_float16 = precision == Precision.FLOAT16
+     float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
+
+     input_count = len(graph.input)
+     layer_count = (input_count - 2) // 4
+     assert layer_count >= 1
+
+     # Expect inputs:
+     #   input_ids: int32 (B, 1)
+     #   encoder_attention_mask: int32 (B, encode_sequence_length)
+
+     #   past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
+     #   past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
+     #   ... (for each self attention layer)
+
+     #   past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
+     #   past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
+     #   ... (for each cross attention layer)
+
+     # TODO: encoder_hidden_states is optional
+     expected_inputs = ["input_ids", "encoder_attention_mask"]
+     for i in range(layer_count):
+         expected_inputs.append(f"past_key_self_{i}")
+         expected_inputs.append(f"past_value_self_{i}")
+     for i in range(layer_count):
+         expected_inputs.append(f"past_key_cross_{i}")
+         expected_inputs.append(f"past_value_cross_{i}")
+
+     if len(graph.input) != len(expected_inputs):
+         raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
+
+     for i, expected_input in enumerate(expected_inputs):
+         if graph.input[i].name != expected_input:
+             raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
+
+         expected_type = TensorProto.INT32 if i < 2 else float_type
+         input_type = graph.input[i].type.tensor_type.elem_type
+         if input_type != expected_type:
+             raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
+
+     # Expect outputs:
+     #   logits: (B, 1, vocab_size)
+     #   present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
+     #   present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
+     #   ... (for each self attention layer)
+     expected_outputs = ["logits"]
+     for i in range(layer_count):
+         expected_outputs.append(f"present_key_self_{i}")
+         expected_outputs.append(f"present_value_self_{i}")
+
+     if len(graph.output) != len(expected_outputs):
+         raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
+
+     for i, expected_output in enumerate(expected_outputs):
+         if graph.output[i].name != expected_output:
+             raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
+         output_type = graph.output[i].type.tensor_type.elem_type
+         if output_type != float_type:
+             raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}")
+
+
+ def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision):
+     """Verify T5 decoder subgraph
+
+     Args:
+         graph (onnx.GraphProto): onnx graph of T5 decoder
+         precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
+
+     Raises:
+         ValueError: Number of inputs not expected.
+         ValueError: Input name is not expected.
+         ValueError: Input data type is not expected.
+         ValueError: Number of outputs not expected.
+         ValueError: Output name is not expected.
+         ValueError: Output data type is not expected.
+     """
+     is_float16 = precision == Precision.FLOAT16
+     layer_count = (len(graph.output) - 2) // 4
+     assert layer_count >= 1
+
+     # Expect 3 inputs:
+     #   encoder_input_ids: int32 (B, encode_sequence_length)
+     #   encoder_attention_mask: int32 (B, encode_sequence_length)
+     #   decoder_input_ids: int32 (B, 1)
+     expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"]
+     if len(graph.input) != len(expected_inputs):
+         raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
+
+     for i, expected_input in enumerate(expected_inputs):
+         if graph.input[i].name != expected_input:
+             raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
+
+         expected_type = TensorProto.INT32
+         input_type = graph.input[i].type.tensor_type.elem_type
+         if input_type != expected_type:
+             raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
+
+     # Expected outputs:
+     #   logits: (B, 1, vocab_size)
+     #   encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
+     #   present_key_self_0: (B, num_heads, 1, head_size)
+     #   present_value_self_0: (B, num_heads, 1, head_size)
+     #   ... (for each self attention layer)
+     #   present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
+     #   present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
+     #   ... (for each cross attention layer)
+     expected_outputs = ["logits", "encoder_hidden_states"]
+     for i in range(layer_count):
+         expected_outputs.append(f"present_key_self_{i}")
+         expected_outputs.append(f"present_value_self_{i}")
+     for i in range(layer_count):
+         expected_outputs.append(f"present_key_cross_{i}")
+         expected_outputs.append(f"present_value_cross_{i}")
+
+     if len(graph.output) != len(expected_outputs):
+         raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
+
+     for i, expected_output in enumerate(expected_outputs):
+         if graph.output[i].name != expected_output:
+             raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
+
+         expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
+         output_type = graph.output[i].type.tensor_type.elem_type
+         if output_type != expected_type:
+             raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}")
+
+     logger.info("T5 encoder graph verified: name and data type of inputs and outputs are good.")
879
+
880
+
881
+ def remove_shared_initializers(
882
+ graph1: GraphProto,
883
+ graph2: GraphProto,
884
+ shared_prefix: str = "shared_",
885
+ min_elements: int = 1024,
886
+ signature_cache1: Optional[dict] = None,
887
+ signature_cache2: Optional[dict] = None,
888
+ ):
889
+ """Remove initializers with same value from two graphs.
890
+
891
+ Args:
892
+ graph1 (GraphProto): the first graph to process
893
+ graph2 (GraphProto): the second graph to process
894
+ shared_prefix (str): add prefix to the shared initializers among two graphs
895
+ min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024.
896
+ signature_cache1 (dict): Optional dictionary to store data signatures of tensors in graph1 in order to speed up comparison
897
+ signature_cache2 (dict): Optional dictionary to store data signatures of tensors in graph2 in order to speed up comparison
898
+ """
899
+
900
+ mapping_initializers_1 = {}
901
+ mapping_initializers_2 = {}
902
+ shared_initializers_1 = []
903
+ shared_initializers_2 = []
904
+ shared_initializers_names = []
905
+
906
+ for initializer1 in graph1.initializer:
907
+ if not (initializer1.dims and sum(initializer1.dims) >= min_elements):
908
+ continue
909
+
910
+ for initializer2 in graph2.initializer:
911
+ if not (initializer2.dims and sum(initializer2.dims) >= min_elements):
912
+ continue
913
+
914
+ if OnnxModel.has_same_value(initializer1, initializer2, signature_cache1, signature_cache2):
915
+ mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name
916
+ shared_initializers_1.append(initializer1)
917
+
918
+ if initializer2.name not in mapping_initializers_2:
919
+ shared_name = shared_prefix + initializer2.name
920
+ mapping_initializers_2[initializer2.name] = shared_name
921
+ shared_initializers_2.append(initializer2)
922
+ shared_initializers_names.append(shared_name)
923
+ break
924
+
925
+ logger.debug(f"shared initializers:{shared_initializers_names}")
926
+
927
+ # Make sure new name does not exist in graph 1
928
+ for node in graph1.node:
929
+ for j in range(len(node.input)):
930
+ if node.input[j] in shared_initializers_names:
931
+ raise RuntimeError(f"name is found in graph 1: {node.input[j]}")
932
+
933
+ # Make sure new name does not exist in graph 2
934
+ for node in graph2.node:
935
+ for j in range(len(node.input)):
936
+ if node.input[j] in shared_initializers_names:
937
+ raise RuntimeError(f"name is found in graph 2: {node.input[j]}")
938
+
939
+ # Remove shared initializers from graph 2
940
+ for initializer in shared_initializers_2:
941
+ graph2.initializer.remove(initializer)
942
+
943
+ # Rename value info for old names in graph 2
944
+ for value_info in graph2.value_info:
945
+ if value_info.name in mapping_initializers_2:
946
+ value_info.name = mapping_initializers_2[value_info.name]
947
+
948
+ # Rename nodes inputs in graph 2:
949
+ for node in graph2.node:
950
+ for j in range(len(node.input)):
951
+ if node.input[j] in mapping_initializers_2:
952
+ new_name = mapping_initializers_2[node.input[j]]
953
+ logger.debug(f"graph 2 rename node {node.name} input {j} from {node.input[j]} to {new_name}")
954
+ node.input[j] = new_name
955
+
956
+ # Remove shared initializers from graph 1
957
+ for initializer in shared_initializers_1:
958
+ graph1.initializer.remove(initializer)
959
+
960
+ # Rename value info for old names in graph 1
961
+ for value_info in graph1.value_info:
962
+ if value_info.name in mapping_initializers_1:
963
+ value_info.name = mapping_initializers_1[value_info.name]
964
+
965
+ # Rename nodes inputs in graph 1:
966
+ for node in graph1.node:
967
+ for j in range(len(node.input)):
968
+ if node.input[j] in mapping_initializers_1:
969
+ new_name = mapping_initializers_1[node.input[j]]
970
+ logger.debug(f"graph 1 rename node {node.name} input {j} from {node.input[j]} to {new_name}")
971
+ node.input[j] = new_name
972
+
973
+ # Rename shared initializers in graph 2
974
+ for initializer in shared_initializers_2:
975
+ initializer.name = mapping_initializers_2[initializer.name]
976
+
977
+ for initializer in shared_initializers_2:
978
+ shape = onnx.numpy_helper.to_array(initializer).shape
979
+ value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
980
+ # Need to add value_info for initializers moved to the parent graph. Otherwise, ORT will fail.
981
+ graph1.value_info.append(value_info)
982
+ graph2.value_info.append(value_info)
983
+
984
+ return shared_initializers_2
985
+
986
+
987
+ def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto):
988
+ encoder = OnnxModel(encoder_model)
989
+ decoder = OnnxModel(decoder_model)
990
+ encoder.add_prefix_to_names("e_")
991
+ decoder.add_prefix_to_names("d_")
992
+ signature_cache1, signature_cache2 = {}, {}
993
+ encoder.remove_duplicated_initializer(signature_cache1)
994
+ decoder.remove_duplicated_initializer(signature_cache2)
995
+ initializers = remove_shared_initializers(
996
+ decoder.model.graph,
997
+ encoder.model.graph,
998
+ shared_prefix="s_",
999
+ signature_cache1=signature_cache1,
1000
+ signature_cache2=signature_cache2,
1001
+ )
1002
+ return initializers
1003
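# --- Illustrative sketch (not part of the packaged file): how the shared
# initializers returned above could be promoted to a parent graph that embeds
# both subgraphs. The paths and the parent graph wiring are hypothetical.
def _example_share_initializers_across_subgraphs():
    import onnx

    encoder_model = onnx.load("encoder.onnx")  # hypothetical path
    decoder_model = onnx.load("decoder.onnx")  # hypothetical path

    shared = get_shared_initializers(encoder_model, decoder_model)

    # After the call, neither subgraph stores the shared tensors anymore, so a
    # parent graph (e.g. the one holding a BeamSearch node) must own them.
    parent_graph = onnx.helper.make_graph(
        nodes=[], name="parent", inputs=[], outputs=[], initializer=shared
    )
    return parent_graph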
+
1004
+
1005
+ def move_initializers(
1006
+ graph: GraphProto,
1007
+ min_elements: int = 1024,
1008
+ ) -> List[TensorProto]:
1009
+ """Remove initializers of a graph, when they have number of elements larger than a threshold.
1010
+
1011
+ Args:
1012
+ graph (GraphProto): the graph.
1013
+ min_elements (int, optional): minimum number of elements for an initializer to be considered. Defaults to 1024.
1014
+
1015
+ Returns:
1016
+ List[TensorProto]: initializers that are removed from the graph.
1017
+ """
1018
+ moved_initializers = []
1019
+ for tensor in graph.initializer:
1020
+ if not (tensor.dims and sum(tensor.dims) >= min_elements):
1021
+ continue
1022
+ moved_initializers.append(tensor)
1023
+
1024
+ for initializer in moved_initializers:
1025
+ graph.initializer.remove(initializer)
1026
+
1027
+ # Add type info, otherwise ORT will raise error: "input arg (*) does not have type information set by parent node."
1028
+ for initializer in moved_initializers:
1029
+ shape = onnx.numpy_helper.to_array(initializer).shape
1030
+ value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
1031
+ graph.value_info.append(value_info)
1032
+
1033
+ return moved_initializers
1034
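# --- Illustrative sketch (not part of the packaged file): moving large
# initializers out of a subgraph so a parent graph can own them. The path is a
# placeholder.
def _example_move_initializers():
    import onnx

    model = onnx.load("decoder.onnx")  # hypothetical path
    moved = move_initializers(model.graph, min_elements=1024)
    # The caller is responsible for re-homing the removed tensors, typically by
    # extending the parent graph's initializer list with `moved`.
    return moved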
+
1035
+
1036
+ def _attribute_to_pair(attribute):
1037
+ """
1038
+ Convert attribute to kwarg format for use with onnx.helper.make_node.
1039
+ :parameter attribute: attribute in AttributeProto format.
1040
+ :return: attribute in {key: value} format.
1041
+ """
1042
+ if attribute.type == 0:
1043
+ raise ValueError(f"attribute {attribute.name} does not have type specified.")
1044
+
1045
+ # Based on attribute type definitions from AttributeProto
1046
+ # definition in https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
1047
+ if attribute.type == 1:
1048
+ value = attribute.f
1049
+ elif attribute.type == 2:
1050
+ value = attribute.i
1051
+ elif attribute.type == 3:
1052
+ value = attribute.s
1053
+ elif attribute.type == 4:
1054
+ value = attribute.t
1055
+ elif attribute.type == 5:
1056
+ value = attribute.g
1057
+ elif attribute.type == 6:
1058
+ value = attribute.floats
1059
+ elif attribute.type == 7:
1060
+ value = attribute.ints
1061
+ elif attribute.type == 8:
1062
+ value = attribute.strings
1063
+ elif attribute.type == 9:
1064
+ value = attribute.tensors
1065
+ elif attribute.type == 10:
1066
+ value = attribute.graphs
1067
+ else:
1068
+ raise ValueError(f"attribute {attribute.name} has unsupported type {attribute.type}.")
1069
+
1070
+ return (attribute.name, value)
1071
+
1072
+
1073
+ def kwargs_of(node):
1074
+ kwargs = {}
1075
+ for attr in node.attribute:
1076
+ (key, value) = _attribute_to_pair(attr)
1077
+ kwargs.update({key: value})
1078
+ if node.domain:
1079
+ kwargs.update({"domain": node.domain})
1080
+ return kwargs
1081
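# --- Illustrative sketch (not part of the packaged file): the attribute
# round-trip pattern used throughout this module - read a node's attributes as
# kwargs, tweak them, and rebuild the node with onnx.helper.make_node.
def _example_rebuild_node_with_extra_attribute(node):
    import onnx

    kwargs = kwargs_of(node)                 # existing attributes (and domain) as a dict
    kwargs["past_present_share_buffer"] = 1  # add or override an attribute
    return onnx.helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)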
+
1082
+
1083
+ def shape_of(vi):
1084
+ return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])
1085
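# --- Illustrative sketch (not part of the packaged file): shape_of keeps
# symbolic dims as strings and fixed dims as ints.
def _example_shape_of():
    import onnx

    vi = onnx.helper.make_tensor_value_info(
        "past_key_self_0", onnx.TensorProto.FLOAT, ["batch_size", 8, "past_seq_len", 64]
    )
    assert shape_of(vi) == ("batch_size", 8, "past_seq_len", 64)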
+
1086
+
1087
+ def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto):
1088
+ input_past_0 = 3
1089
+ output_past_0 = 1
1090
+ new_inputs = []
1091
+ for i, vi in enumerate(subg.input):
1092
+ if i >= input_past_0:
1093
+ shape = shape_of(vi)
1094
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
1095
+ vi.name,
1096
+ elem_type=vi.type.tensor_type.elem_type,
1097
+ shape=[shape[0], shape[1], shape[2], "max_seq_len", shape[4]],
1098
+ )
1099
+ new_inputs.extend([vi])
1100
+ new_inputs.extend([onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])])
1101
+ subg.ClearField("input")
1102
+ subg.input.extend(new_inputs)
1103
+
1104
+ new_outputs = []
1105
+ for i, vi in enumerate(subg.output):
1106
+ if i >= output_past_0:
1107
+ shape = shape_of(vi)
1108
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
1109
+ vi.name,
1110
+ elem_type=vi.type.tensor_type.elem_type,
1111
+ shape=[shape[0], shape[1], shape[2], "max_seq_len", shape[4]],
1112
+ )
1113
+ new_outputs.extend([vi])
1114
+ subg.ClearField("output")
1115
+ subg.output.extend(new_outputs)
1116
+
1117
+ new_nodes = []
1118
+ for node in subg.node:
1119
+ if node.op_type == "Attention":
1120
+ kwargs = kwargs_of(node)
1121
+ kwargs.update({"past_present_share_buffer": 1})
1122
+ nis = []
1123
+ nis.extend(node.input)
1124
+ while len(nis) < 6:
1125
+ nis.extend([""])
1126
+ if len(nis) < 7:
1127
+ nis.extend(["past_sequence_length"])
1128
+ node = onnx.helper.make_node("Attention", nis, node.output, name=node.name, **kwargs) # noqa: PLW2901
1129
+ new_nodes.extend([node])
1130
+ subg.ClearField("node")
1131
+ subg.node.extend(new_nodes)
1132
+ return subg
1133
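# --- Illustrative sketch (not part of the packaged file): preparing a GPT-2
# decoder subgraph for buffer sharing before embedding it into a BeamSearch or
# GreedySearch node. The path is a placeholder.
def _example_enable_past_present_share_buffer():
    import onnx

    decoder = onnx.load("gpt2_decoder.onnx")  # hypothetical path
    # Rewrites the past/present value_infos to use a "max_seq_len" dim, appends a
    # past_sequence_length input, and sets past_present_share_buffer=1 on Attention nodes.
    update_decoder_subgraph_past_present_share_buffer(decoder.graph)
    return decoder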
+
1134
+
1135
+ def update_decoder_subgraph_use_decoder_masked_attention(
1136
+ subg: GraphProto, is_beam_search: bool, switch_attention: bool
1137
+ ) -> bool:
1138
+ """Update the Attention nodes to DecoderMaskedSelfAttention.
1139
+
1140
+ Args:
1141
+ subg (GraphProto): GraphProto of the decoder subgraph
1142
+ is_beam_search (bool): Boolean specifying if the sampling algo is BeamSearch
1143
+ switch_attention (bool): Boolean specifying if `Attention` is to be replaced with `DecoderMaskedSelfAttention`
1144
+ """
1145
+ if is_beam_search:
1146
+ new_inputs = []
1147
+ for _i, vi in enumerate(subg.input):
1148
+ new_inputs.extend([vi])
1149
+
1150
+ # Add 2 BeamSearch specific inputs
1151
+ new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
1152
+ new_inputs.extend(
1153
+ [
1154
+ onnx.helper.make_tensor_value_info(
1155
+ "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
1156
+ )
1157
+ ]
1158
+ )
1159
+ subg.ClearField("input")
1160
+ subg.input.extend(new_inputs)
1161
+
1162
+ if switch_attention:
1163
+ decoder_masked_attention_supported_attr = [
1164
+ "past_present_share_buffer",
1165
+ "num_heads",
1166
+ "scale",
1167
+ "mask_filter_value",
1168
+ "domain",
1169
+ ]
1170
+
1171
+ new_nodes = []
1172
+ for node in subg.node:
1173
+ if node.op_type == "Attention":
1174
+ kwargs = kwargs_of(node)
1175
+ for k in kwargs.copy():
1176
+ # The Attention operator does not support different qkv hidden sizes when past/present
1177
+ # input/output exists (GPT2 model). Hence, we should never run into this.
1178
+ # But, if we do, do not go ahead with the optimization.
1179
+ if k == "qkv_hidden_sizes":
1180
+ return False
1181
+
1182
+ if k not in decoder_masked_attention_supported_attr:
1183
+ # Log the fact that we are removing certain attributes from the node
1184
+ # We don't need to log it for "unidirectional" as we are aware that
1185
+ # decoding attention kernels are unidirectional by definition.
1186
+ if k != "unidirectional":
1187
+ logger.warning(
1188
+ f"Removing attribute: {k} from Attention node while switching to DecoderMaskedSelfAttention"
1189
+ )
1190
+
1191
+ del kwargs[k]
1192
+
1193
+ nis = []
1194
+ nis.extend(node.input)
1195
+
1196
+ # Add 2 BeamSearch specific inputs
1197
+ if is_beam_search:
1198
+ while len(nis) < 7:
1199
+ nis.extend([""])
1200
+ if len(nis) < 8:
1201
+ nis.extend(["beam_width"])
1202
+ if len(nis) < 9:
1203
+ nis.extend(["cache_indirection"])
1204
+
1205
+ node = onnx.helper.make_node( # noqa: PLW2901
1206
+ "DecoderMaskedSelfAttention", nis, node.output, name=node.name, **kwargs
1207
+ )
1208
+ new_nodes.extend([node])
1209
+ subg.ClearField("node")
1210
+ subg.node.extend(new_nodes)
1211
+
1212
+ return True
1213
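# --- Illustrative sketch (not part of the packaged file): switching a GPT-2
# decoder subgraph to DecoderMaskedSelfAttention for beam search. The path is a
# placeholder.
def _example_use_decoder_masked_attention():
    import onnx

    decoder = onnx.load("gpt2_decoder.onnx")  # hypothetical path
    ok = update_decoder_subgraph_use_decoder_masked_attention(
        decoder.graph, is_beam_search=True, switch_attention=True
    )
    if not ok:
        # An Attention node carried qkv_hidden_sizes, so the rewrite was skipped.
        print("DecoderMaskedSelfAttention rewrite not applied")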
+
1214
+
1215
+ def find_past_seq_len_usage(subg: GraphProto):
1216
+ """Correct graph which originally use dim of past_seq_len from input_ids's shape which is fixed to max_seq_len after
1217
+ shared past/present buffer
1218
+
1219
+ Args:
1220
+ subg (GraphProto): GraphProto of the decoder subgraph
1221
+ return:
1222
+ tensor_names_to_rename : set of tensor names whose values equal past_sequence_length
1223
+ nodes_to_remove : list of nodes to remove
1224
+ """
1225
+ tensor_names_to_rename = set()
1226
+ nodes_to_remove = []
1227
+
1228
+ graph_input_names = {inp.name: index for index, inp in enumerate(subg.input)}
1229
+
1230
+ input_name_to_nodes = {}
1231
+ output_name_to_node = {}
1232
+ for node in subg.node:
1233
+ for input_name in node.input:
1234
+ if input_name:
1235
+ if input_name not in input_name_to_nodes:
1236
+ input_name_to_nodes[input_name] = [node]
1237
+ else:
1238
+ input_name_to_nodes[input_name].append(node)
1239
+ for output_name in node.output:
1240
+ if output_name:
1241
+ output_name_to_node[output_name] = node
1242
+
1243
+ for node in subg.node:
1244
+ # find "Shape(past_key_self..) --> Gather(*, 2)"
1245
+ if node.op_type == "Gather":
1246
+ if not node.input[1] or not node.input[0]:
1247
+ continue
1248
+ shape_tensor_name, shape_index_name = (node.input[0], node.input[1])
1249
+ ini_gather_indices = None
1250
+ for tensor in subg.initializer:
1251
+ if tensor.name == shape_index_name:
1252
+ ini_gather_indices = tensor
1253
+ break
1254
+ if ini_gather_indices is None:
1255
+ continue
1256
+ gather_indices_arr = onnx.numpy_helper.to_array(ini_gather_indices)
1257
+ if gather_indices_arr.size == 1 and gather_indices_arr.item() == 2 and node.input[0] in output_name_to_node:
1258
+ shape_node = output_name_to_node[shape_tensor_name]
1259
+ if (
1260
+ shape_node.op_type == "Shape"
1261
+ and shape_node.input[0]
1262
+ and shape_node.input[0] in graph_input_names
1263
+ and (
1264
+ shape_node.input[0].startswith("past_key_self_")
1265
+ or shape_node.input[0].startswith("past_value_self_")
1266
+ )
1267
+ ):
1268
+ tensor_names_to_rename.add(node.output[0])
1269
+ nodes_to_remove.append(node)
1270
+ if len(input_name_to_nodes[shape_node.output[0]]) == 1:
1271
+ nodes_to_remove.append(shape_node)
1272
+ return tensor_names_to_rename, nodes_to_remove
1273
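# --- Illustrative sketch (not part of the packaged file): consuming the result
# of find_past_seq_len_usage the way the shared-buffer rewrite below does -
# rename the detected tensors and drop the now-unused Shape/Gather nodes. The
# path and replacement name are placeholders.
def _example_fix_past_seq_len_usage():
    import onnx

    subg = onnx.load("t5_decoder.onnx").graph  # hypothetical path
    names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg)
    for node in subg.node:
        for i, name in enumerate(node.input):
            if name in names_to_rename:
                node.input[i] = "past_sequence_length_squeezed_int64"
    for node in nodes_to_remove:
        subg.node.remove(node)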
+
1274
+
1275
+ def replace_mha_with_gqa(
1276
+ model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1
1277
+ ):
1278
+ # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
1279
+ #
1280
+ # attention_mask
1281
+ # / \
1282
+ # ReduceSum Shape
1283
+ # | |
1284
+ # Sub Gather
1285
+ # | |
1286
+ # seqlens_k total_sequence_length
1287
+ # | |
1288
+ # Cast to int32 Cast to int32
1289
+
1290
+ model.add_initializer(
1291
+ onnx.helper.make_tensor(
1292
+ name="one",
1293
+ data_type=TensorProto.INT64,
1294
+ dims=[1],
1295
+ vals=[1],
1296
+ )
1297
+ )
1298
+ reduce_sum_node = onnx.helper.make_node(
1299
+ "ReduceSum",
1300
+ inputs=[attn_mask, "one"],
1301
+ outputs=[attn_mask + "_row_sums"],
1302
+ name=model.create_node_name("ReduceSum"),
1303
+ )
1304
+ sub_node = onnx.helper.make_node(
1305
+ "Sub",
1306
+ inputs=[attn_mask + "_row_sums", "one"],
1307
+ outputs=["seqlens_k_int64"],
1308
+ name=model.create_node_name("Sub"),
1309
+ )
1310
+ seqlen_k_cast_node = onnx.helper.make_node(
1311
+ "Cast",
1312
+ inputs=["seqlens_k_int64"],
1313
+ outputs=["seqlens_k"],
1314
+ name=model.create_node_name("Cast"),
1315
+ to=TensorProto.INT32,
1316
+ )
1317
+ shape_node = onnx.helper.make_node(
1318
+ "Shape",
1319
+ inputs=[attn_mask],
1320
+ outputs=[attn_mask + "_shape"],
1321
+ name=model.create_node_name("Shape"),
1322
+ )
1323
+ gather_node = onnx.helper.make_node(
1324
+ "Gather",
1325
+ inputs=[attn_mask + "_shape", "one"],
1326
+ outputs=["total_seq_len_int64"],
1327
+ name=model.create_node_name("Gather"),
1328
+ axis=0,
1329
+ )
1330
+ total_seqlen_cast_node = onnx.helper.make_node(
1331
+ "Cast",
1332
+ inputs=["total_seq_len_int64"],
1333
+ outputs=["total_seq_len"],
1334
+ name=model.create_node_name("Cast"),
1335
+ to=TensorProto.INT32,
1336
+ )
1337
+ model.model.graph.node.extend(
1338
+ [reduce_sum_node, sub_node, seqlen_k_cast_node, shape_node, gather_node, total_seqlen_cast_node]
1339
+ )
1340
+
1341
+ # Replace MultiHeadAttention with GroupQueryAttention
1342
+ #
1343
+ # When replacing, fuse the following subgraph:
1344
+ #
1345
+ # root_input
1346
+ # / | \
1347
+ # MatMul MatMul MatMul
1348
+ # | | |
1349
+ # Add Add Add (optional Adds)
1350
+ # | | |
1351
+ # RotEmb RotEmb |
1352
+ # \ | /
1353
+ # MultiHeadAttention
1354
+ #
1355
+ # to this new subgraph:
1356
+ #
1357
+ # root_input
1358
+ # |
1359
+ # PackedMatMul (if possible)
1360
+ # |
1361
+ # PackedAdd (if possible)
1362
+ # |
1363
+ # GroupQueryAttention
1364
+ #
1365
+
1366
+ mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
1367
+ for idx, node in enumerate(mha_nodes):
1368
+ # Detect Q path to MHA
1369
+ q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
1370
+ q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])
1371
+
1372
+ q_rotary, q_add, q_matmul = None, None, None
1373
+ if q_path_1 is not None:
1374
+ q_rotary, q_add, q_matmul = q_path_1
1375
+ elif q_path_2 is not None:
1376
+ q_rotary, q_matmul = q_path_2
1377
+
1378
+ # Detect K path to MHA
1379
+ k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
1380
+ k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])
1381
+
1382
+ k_rotary, k_add, k_matmul = None, None, None
1383
+ if k_path_1 is not None:
1384
+ k_rotary, k_add, k_matmul = k_path_1
1385
+ elif k_path_2 is not None:
1386
+ k_rotary, k_matmul = k_path_2
1387
+
1388
+ # Detect V path to MHA
1389
+ v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
1390
+ v_path_2 = model.match_parent_path(node, ["MatMul"], [2])
1391
+
1392
+ v_add, v_matmul = None, None
1393
+ if v_path_1 is not None:
1394
+ v_add, v_matmul = v_path_1
1395
+ elif v_path_2 is not None:
1396
+ v_matmul = v_path_2[0]
1397
+
1398
+ # Get `interleaved` attribute from RotaryEmbedding
1399
+ interleaved = 0
1400
+ if q_rotary is not None and k_rotary is not None:
1401
+ for att in q_rotary.attribute:
1402
+ if att.name == "interleaved":
1403
+ interleaved = att.i
1404
+
1405
+ # Get `num_heads` attribute from MHA
1406
+ num_heads = 0
1407
+ for att in node.attribute:
1408
+ if att.name == "num_heads":
1409
+ num_heads = att.i
1410
+
1411
+ # Check if root_input to Q/K/V paths is the same
1412
+ root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
1413
+
1414
+ # Check if Q/K/V paths all have bias or all don't have bias
1415
+ all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
1416
+ all_paths_have_no_bias = q_add is None and k_add is None and v_add is None
1417
+
1418
+ # Make PackedMatMul node if possible
1419
+ q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
1420
+ if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
1421
+ qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
1422
+ kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
1423
+ vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))
1424
+
1425
+ dim = qw.shape[-1]
1426
+ qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
1427
+ qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
1428
+ model.add_initializer(qkv_weight)
1429
+
1430
+ packed_matmul_node = onnx.helper.make_node(
1431
+ "MatMul",
1432
+ inputs=[q_matmul.input[0], qkv_weight.name],
1433
+ outputs=[f"{qkv_weight.name}_output"],
1434
+ name=model.create_node_name("MatMul"),
1435
+ )
1436
+ model.model.graph.node.extend([packed_matmul_node])
1437
+ model.model.graph.node.remove(q_matmul)
1438
+ model.model.graph.node.remove(k_matmul)
1439
+ model.model.graph.node.remove(v_matmul)
1440
+ q_input_to_attention = packed_matmul_node.output[0]
1441
+
1442
+ # Make PackedAdd node if possible
1443
+ if all_paths_have_bias:
1444
+ qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
1445
+ kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
1446
+ vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))
1447
+
1448
+ dim = qb.shape[-1]
1449
+ qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
1450
+ qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
1451
+ model.add_initializer(qkv_bias)
1452
+ packed_add_node = onnx.helper.make_node(
1453
+ "Add",
1454
+ inputs=[packed_matmul_node.output[0], qkv_bias.name],
1455
+ outputs=[f"{qkv_bias.name}_output"],
1456
+ )
1457
+ model.model.graph.node.extend([packed_add_node])
1458
+ model.model.graph.node.remove(q_add)
1459
+ model.model.graph.node.remove(k_add)
1460
+ model.model.graph.node.remove(v_add)
1461
+ q_input_to_attention = packed_add_node.output[0]
1462
+
1463
+ else:
1464
+ q_input_to_attention = q_matmul.output[0]
1465
+ k_input_to_attention = k_matmul.output[0]
1466
+ v_input_to_attention = v_matmul.output[0]
1467
+
1468
+ # Make GQA node
1469
+ gqa_node = onnx.helper.make_node(
1470
+ "GroupQueryAttention",
1471
+ inputs=[
1472
+ q_input_to_attention, # query
1473
+ k_input_to_attention, # key
1474
+ v_input_to_attention, # value
1475
+ node.input[6], # past_key
1476
+ node.input[7], # past_value
1477
+ seqlen_k_cast_node.output[0], # seqlens_k (for attention mask)
1478
+ total_seqlen_cast_node.output[0], # total_seq_len (for attention mask)
1479
+ q_rotary.input[2] if q_rotary is not None else "", # cos_cache (for rotary embeddings)
1480
+ q_rotary.input[3] if q_rotary is not None else "", # sin_cache (for rotary embeddings)
1481
+ ],
1482
+ outputs=node.output,
1483
+ name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
1484
+ domain="com.microsoft",
1485
+ num_heads=num_heads // world_size,
1486
+ kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
1487
+ local_window_size=window_size,
1488
+ do_rotary=int(q_rotary is not None and k_rotary is not None),
1489
+ rotary_interleaved=interleaved,
1490
+ )
1491
+ model.model.graph.node.remove(node)
1492
+ model.model.graph.node.extend([gqa_node])
1493
+
1494
+ if q_rotary is not None:
1495
+ model.model.graph.node.remove(q_rotary)
1496
+ if k_rotary is not None:
1497
+ model.model.graph.node.remove(k_rotary)
1498
+
1499
+ return model
1500
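# --- Illustrative sketch (not part of the packaged file): replacing fused
# MultiHeadAttention nodes with GroupQueryAttention in a decoder model. The
# paths and the "attention_mask" input name are placeholders.
def _example_replace_mha_with_gqa():
    import onnx

    model = OnnxModel(onnx.load("decoder_model.onnx"))  # hypothetical path
    model = replace_mha_with_gqa(model, attn_mask="attention_mask", kv_num_heads=0, world_size=1)
    # Assumption: OnnxModel exposes save_model_to_file for writing the result back out.
    model.save_model_to_file("decoder_model_gqa.onnx", use_external_data_format=True)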
+
1501
+
1502
+ def update_decoder_subgraph_output_cross_attention(subg: GraphProto):
1503
+ input_self_past_0 = 1
1504
+ # w/wo attention mask, w/wo hidden_state
1505
+ graph_input_names = [gi.name for gi in subg.input]
1506
+ while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
1507
+ input_self_past_0 += 1
1508
+ output_self_present_0 = 1
1509
+
1510
+ num_layers = (len(subg.output) - output_self_present_0) // 2
1511
+ input_cross_past_0 = 2 * num_layers + input_self_past_0
1512
+ past_key_cross_inputs = {subg.input[layer * 2 + input_cross_past_0].name: layer for layer in range(num_layers)}
1513
+ print(f" --past_key_cross_inputs={past_key_cross_inputs}")
1514
+
1515
+ input_past_key_cross_0_shape = shape_of(subg.input[input_cross_past_0])
1516
+ print(f"past_key_cross_0_shape is {input_past_key_cross_0_shape}")
1517
+ batch_size_dim = input_past_key_cross_0_shape[0]
1518
+ num_heads_dim = input_past_key_cross_0_shape[1]
1519
+ cross_seq_len_dim = input_past_key_cross_0_shape[2]
1520
+
1521
+ num_layer_output_qk = 0
1522
+ for node in subg.node:
1523
+ if (node.op_type == "DecoderMaskedMultiHeadAttention") and (node.input[1] in past_key_cross_inputs):
1524
+ print(f" -- add cross QK output from: node: {node.name} with output: {node.output}")
1525
+ num_layer_output_qk += 1
1526
+ layer = past_key_cross_inputs[node.input[1]]
1527
+ cross_attention_out_name = f"output_cross_qk_{layer}"
1528
+ appended_names = [""] * (3 - len(node.output))
1529
+ appended_names.append(cross_attention_out_name)
1530
+ node.output.extend(appended_names)
1531
+ node.attribute.extend([onnx.helper.make_attribute("output_qk", 1)])
1532
+
1533
+ cross_attention = onnx.helper.make_tensor_value_info(
1534
+ cross_attention_out_name, TensorProto.FLOAT, [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim]
1535
+ )
1536
+ subg.output.extend([cross_attention])
1537
+ if num_layer_output_qk != num_layers:
1538
+ raise ValueError(f"Did not add cross QK for all layers{num_layers} vs {num_layer_output_qk}")
1539
+
1540
+
1541
+ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelProto):
1542
+ input_self_past_0 = 1
1543
+ # w/wo attention mask, w/wo hidden_state
1544
+ graph_input_names = [gi.name for gi in subg.input]
1545
+ while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
1546
+ input_self_past_0 += 1
1547
+ output_self_past_0 = 1
1548
+
1549
+ num_layers = int((len(subg.input) - input_self_past_0) / 4)
1550
+ input_cross_past_0 = 2 * num_layers + input_self_past_0
1551
+
1552
+ new_nodes = []
1553
+ old_nodes = []
1554
+ for node in subg.node:
1555
+ if node.op_type == "MultiHeadAttention":
1556
+ old_nodes.extend([node])
1557
+
1558
+ # If not all the MultiHeadAttention nodes are fused, this optimization is not applicable
1559
+ if len(old_nodes) < num_layers:
1560
+ return False
1561
+
1562
+ # Redirect the RelativePositionBias node's input from past_key_self_0.shape[2] to past_sequence_length.
1563
+ # There is only one RelativePositionBias node in T5 decoder subgraph.
1564
+ rel_pos_bias_node = None
1565
+ for node in subg.node:
1566
+ if node.op_type == "RelativePositionBias":
1567
+ rel_pos_bias_node = node
1568
+ break
1569
+
1570
+ decoder_masked_attention_supported_attr = [
1571
+ "past_present_share_buffer",
1572
+ "num_heads",
1573
+ "scale",
1574
+ "mask_filter_value",
1575
+ "domain",
1576
+ ]
1577
+
1578
+ target_squeezed_past_seq_name = "past_sequence_length_squeezed_int64"
1579
+ tensor_names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg)
1580
+ if len(tensor_names_to_rename) > 0:
1581
+ for name_to_rename in tensor_names_to_rename:
1582
+ print(f"Found tensor name {name_to_rename} to be renamed to {target_squeezed_past_seq_name}")
1583
+ for nr in nodes_to_remove:
1584
+ print(f"Found node to removed: type:{nr.op_type}, name:{nr.name}")
1585
+
1586
+ squeeze_node = onnx.helper.make_node(
1587
+ "Squeeze",
1588
+ ["past_sequence_length"],
1589
+ ["past_sequence_length_squeezed"],
1590
+ name="node_past_sequence_length_squeeze",
1591
+ )
1592
+ cast_node = onnx.helper.make_node(
1593
+ "Cast",
1594
+ ["past_sequence_length_squeezed"],
1595
+ [target_squeezed_past_seq_name],
1596
+ name="node_past_sequence_length_squeeze_cast",
1597
+ to=TensorProto.INT64,
1598
+ )
1599
+ new_nodes.extend([squeeze_node, cast_node])
1600
+
1601
+ for node in subg.node:
1602
+ if len(node.output) > 0 and rel_pos_bias_node is not None and node.output[0] == rel_pos_bias_node.input[1]:
1603
+ cast_node = onnx.helper.make_node(
1604
+ "Cast",
1605
+ ["past_sequence_length"],
1606
+ ["past_sequence_length_int64"],
1607
+ name="past_sequence_length_cast",
1608
+ to=TensorProto.INT64,
1609
+ )
1610
+ node.input[1] = cast_node.output[0]
1611
+ new_nodes.extend([cast_node])
1612
+
1613
+ if node.op_type == "MultiHeadAttention":
1614
+ kwargs = kwargs_of(node)
1615
+ for k in kwargs.copy():
1616
+ if k not in decoder_masked_attention_supported_attr:
1617
+ del kwargs[k]
1618
+
1619
+ # Note: this logic only applies to the T5 model, where the Attention node has no bias.
1620
+ nis = [
1621
+ node.input[0], # query
1622
+ node.input[1], # key
1623
+ node.input[2], # value
1624
+ ]
1625
+
1626
+ nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask
1627
+ nis.extend([node.input[5] if len(node.input) > 5 else ""]) # attention_bias
1628
+ nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key
1629
+ nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value
1630
+ nis.extend(["past_sequence_length"]) # past_sequence_length
1631
+ nis.extend(["beam_width"]) # beam_width
1632
+ nis.extend(["cache_indirection"]) # cache_indirection
1633
+ nis.extend([node.input[3] if len(node.input) > 3 else ""]) # bias
1634
+
1635
+ kwargs["past_present_share_buffer"] = 1
1636
+
1637
+ node = onnx.helper.make_node( # noqa: PLW2901
1638
+ "DecoderMaskedMultiHeadAttention", nis, node.output, name=node.name, **kwargs
1639
+ )
1640
+
1641
+ if node not in nodes_to_remove:
1642
+ for index, name in enumerate(node.input):
1643
+ if name in tensor_names_to_rename:
1644
+ node.input[index] = target_squeezed_past_seq_name
1645
+ new_nodes.extend([node])
1646
+
1647
+ subg.ClearField("node")
1648
+ subg.node.extend(new_nodes)
1649
+ orig_input_names = [inp.name for inp in subg.input]
1650
+
1651
+ new_inputs = []
1652
+ for i, vi in enumerate(subg.input):
1653
+ if i >= input_self_past_0 and i < input_cross_past_0:
1654
+ shape = shape_of(vi)
1655
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
1656
+ vi.name,
1657
+ elem_type=vi.type.tensor_type.elem_type,
1658
+ shape=[shape[0], shape[1], "max_seq_len", shape[3]],
1659
+ )
1660
+ new_inputs.extend([vi])
1661
+ if "past_sequence_length" not in orig_input_names:
1662
+ new_inputs.extend(
1663
+ [onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])]
1664
+ )
1665
+ if "beam_width" not in orig_input_names:
1666
+ new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
1667
+ if "cache_indirection" not in orig_input_names:
1668
+ new_inputs.extend(
1669
+ [
1670
+ onnx.helper.make_tensor_value_info(
1671
+ "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
1672
+ )
1673
+ ]
1674
+ )
1675
+ subg.ClearField("input")
1676
+ subg.input.extend(new_inputs)
1677
+
1678
+ new_outputs = []
1679
+ for i, vi in enumerate(subg.output):
1680
+ if i >= output_self_past_0:
1681
+ shape = shape_of(vi)
1682
+ vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
1683
+ vi.name,
1684
+ elem_type=vi.type.tensor_type.elem_type,
1685
+ shape=[shape[0], shape[1], "max_seq_len", shape[3]],
1686
+ )
1687
+ new_outputs.extend([vi])
1688
+ subg.ClearField("output")
1689
+ subg.output.extend(new_outputs)
1690
+
1691
+ return True
1692
+
1693
+
1694
+ def pack_qkv_for_decoder_masked_mha(model_proto: ModelProto):
1695
+ onnx_model = OnnxModel(model_proto)
1696
+ output_name_to_node = onnx_model.output_name_to_node()
1697
+
1698
+ nodes_to_add = []
1699
+ nodes_to_remove = []
1700
+ for node in onnx_model.nodes():
1701
+ if node.op_type == "DecoderMaskedMultiHeadAttention":
1702
+ if "past_key_cross" in node.input[1] and "past_value_cross" in node.input[2]:
1703
+ continue
1704
+ q_matmul = output_name_to_node[node.input[0]]
1705
+ k_matmul = output_name_to_node[node.input[1]]
1706
+ v_matmul = output_name_to_node[node.input[2]]
1707
+
1708
+ q_weight = onnx_model.get_initializer(q_matmul.input[1])
1709
+ k_weight = onnx_model.get_initializer(k_matmul.input[1])
1710
+ v_weight = onnx_model.get_initializer(v_matmul.input[1])
1711
+ if not (q_weight and k_weight and v_weight):
1712
+ return False
1713
+
1714
+ qw = NumpyHelper.to_array(q_weight)
1715
+ kw = NumpyHelper.to_array(k_weight)
1716
+ vw = NumpyHelper.to_array(v_weight)
1717
+
1718
+ qkv_weight = np.concatenate([qw, kw, vw], axis=1)
1719
+
1720
+ matmul_node_name = onnx_model.create_node_name("MatMul", name_prefix="MatMul_QKV")
1721
+ weight = onnx.helper.make_tensor(
1722
+ name=matmul_node_name + "_weight",
1723
+ data_type=TensorProto.FLOAT if q_weight.data_type == 1 else TensorProto.FLOAT16,
1724
+ dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
1725
+ vals=qkv_weight.flatten().tolist(),
1726
+ )
1727
+
1728
+ model_proto.graph.initializer.extend([weight])
1729
+
1730
+ matmul_node = onnx.helper.make_node(
1731
+ "MatMul",
1732
+ inputs=[q_matmul.input[0], matmul_node_name + "_weight"],
1733
+ outputs=[matmul_node_name + "_out"],
1734
+ name=matmul_node_name,
1735
+ )
1736
+
1737
+ node.input[0] = matmul_node.output[0]
1738
+ node.input[1] = ""
1739
+ node.input[2] = ""
1740
+
1741
+ nodes_to_add.extend([matmul_node])
1742
+ nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
1743
+
1744
+ onnx_model.add_nodes(nodes_to_add)
1745
+ onnx_model.remove_nodes(nodes_to_remove)
1746
+ onnx_model.update_graph()
1747
+
1748
+ onnx_model.topological_sort()
1749
+
1750
+ return True
1751
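# --- Illustrative sketch (not part of the packaged file): packing Q/K/V weights
# after the decoder subgraph has been switched to DecoderMaskedMultiHeadAttention,
# mirroring the T5 flow in convert_generation_model below. The path is a placeholder.
def _example_pack_qkv():
    import onnx

    decoder_model = onnx.load("t5_decoder.onnx")  # hypothetical path
    if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
        packed = pack_qkv_for_decoder_masked_mha(decoder_model)
        print(f"qkv packed: {packed}")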
+
1752
+
1753
+ def update_input_shapes_for_gpt2_decoder_model(decoder_onnx_path: str, use_external_data_format: bool = True):
1754
+ """Update the input shapes for the inputs "input_ids" and "position_ids" and make the sequence length dim value 1 for each of them.
1755
+ The decoder model will be over-written.
1756
+
1757
+ Args:
1758
+ decoder_onnx_path (str): Path of GPT-2 decoder onnx model
1759
+ use_external_data_format(bool): output tensors to external data or not.
1760
+ """
1761
+
1762
+ decoder_model_proto = onnx.load_model(decoder_onnx_path, load_external_data=True)
1763
+ for i in range(len(decoder_model_proto.graph.input)):
1764
+ if (
1765
+ decoder_model_proto.graph.input[i].name == "input_ids"
1766
+ or decoder_model_proto.graph.input[i].name == "position_ids"
1767
+ ):
1768
+ shape_dim_proto = decoder_model_proto.graph.input[i].type.tensor_type.shape.dim[1]
1769
+
1770
+ # Clear any existing dim_param first
1771
+ if shape_dim_proto.HasField("dim_param"):
1772
+ shape_dim_proto.Clear()
1773
+
1774
+ # Update dim_value to be 1
1775
+ shape_dim_proto.dim_value = 1
1776
+
1777
+ OnnxModel.save(decoder_model_proto, decoder_onnx_path, save_as_external_data=use_external_data_format)
1778
+ return True
1779
+
1780
+
1781
+ def generate_gpt2_init_decoder(
1782
+ decoder_onnx_path: str, init_decoder_onnx_path: str, use_external_data_format: bool = True
1783
+ ) -> bool:
1784
+ """Generates the initial decoder GPT2 subgraph and saves it for downstream use.
1785
+ The initial decoder model will be saved to init_decoder_onnx_path.
1786
+
1787
+ Args:
1788
+ decoder_onnx_path (str): Path of GPT-2 decoder onnx model
1789
+ init_decoder_onnx_path (str): Path of GPT-2 init decoder onnx model
1790
+ use_external_data_format(bool): output tensors to external data or not.
1791
+ """
1792
+ init_decoder_model_proto = onnx.load_model(decoder_onnx_path, load_external_data=True)
1793
+
1794
+ logits_output_name = init_decoder_model_proto.graph.output[0].name
1795
+
1796
+ gpt2_init_decoder_model = OnnxModel(init_decoder_model_proto)
1797
+
1798
+ output_name_to_node = gpt2_init_decoder_model.output_name_to_node()
1799
+ assert logits_output_name in output_name_to_node
1800
+
1801
+ logits_matmul_node = output_name_to_node[logits_output_name]
1802
+
1803
+ # Sanity check - the logits need to be produced by a MatMul node
1804
+ if logits_matmul_node.op_type != "MatMul":
1805
+ return False
1806
+
1807
+ # Try to find the last residual Add
1808
+ # For fp16, there are Casts along the way
1809
+
1810
+ # Normalization Node is : LayerNormalization
1811
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
1812
+ logits_matmul_node,
1813
+ [
1814
+ "Cast",
1815
+ "LayerNormalization",
1816
+ "Add",
1817
+ "Add",
1818
+ "Cast",
1819
+ "MatMul",
1820
+ "Cast",
1821
+ "FastGelu",
1822
+ "Cast",
1823
+ "MatMul",
1824
+ "Cast",
1825
+ "LayerNormalization",
1826
+ "Add",
1827
+ ],
1828
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1829
+ )
1830
+
1831
+ # Normalization Node is : SkipLayerNormalization
1832
+ if logits_matmul_to_residual_add_path is None:
1833
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
1834
+ logits_matmul_node,
1835
+ [
1836
+ "Cast",
1837
+ "SkipLayerNormalization",
1838
+ "Cast",
1839
+ "MatMul",
1840
+ "Cast",
1841
+ "FastGelu",
1842
+ "Cast",
1843
+ "MatMul",
1844
+ "Cast",
1845
+ "SkipLayerNormalization",
1846
+ ],
1847
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
1848
+ )
1849
+
1850
+ # Try without the Casts before and after the MatMuls
1851
+ if logits_matmul_to_residual_add_path is None:
1852
+ # Normalization Node is : LayerNormalization
1853
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
1854
+ logits_matmul_node,
1855
+ ["LayerNormalization", "Add", "Add", "MatMul", "FastGelu", "MatMul", "LayerNormalization", "Add"],
1856
+ [0, 0, 1, 0, 0, 0, 0, 0],
1857
+ )
1858
+
1859
+ # Normalization Node is : SkipLayerNormalization
1860
+ if logits_matmul_to_residual_add_path is None:
1861
+ logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
1862
+ logits_matmul_node,
1863
+ [
1864
+ "SkipLayerNormalization",
1865
+ "MatMul",
1866
+ "FastGelu",
1867
+ "MatMul",
1868
+ "SkipLayerNormalization",
1869
+ ],
1870
+ [0, 1, 0, 0, 0],
1871
+ )
1872
+
1873
+ # TODO(hasesh): Are there more permutations to try before returning?
1874
+ if logits_matmul_to_residual_add_path is None:
1875
+ return False
1876
+
1877
+ residual_add_node = logits_matmul_to_residual_add_path[-1]
1878
+
1879
+ # If the last node in the pattern is SkipLayerNormalization, we need to adjust our pattern searches accordingly
1880
+ is_skiplayernorm_path = residual_add_node.op_type == "SkipLayerNormalization"
1881
+
1882
+ # Regular LayerNormalization path
1883
+ if not is_skiplayernorm_path:
1884
+ residual_add_to_attention_parent_index = 0
1885
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1886
+ residual_add_node, ["Add", "Cast", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0, 0]
1887
+ )
1888
+
1889
+ # Try other parent index of the residual Add node
1890
+ if residual_add_to_attention_path is None:
1891
+ residual_add_to_attention_parent_index = 1
1892
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1893
+ residual_add_node,
1894
+ ["Add", "Cast", "MatMul", "Attention"],
1895
+ [residual_add_to_attention_parent_index, 0, 0, 0],
1896
+ )
1897
+
1898
+ # Try without the Casts before and after the MatMuls
1899
+ if residual_add_to_attention_path is None:
1900
+ residual_add_to_attention_parent_index = 0
1901
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1902
+ residual_add_node, ["Add", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0]
1903
+ )
1904
+
1905
+ # Try without the Casts before and after the MatMuls and other parent index of the residual Add node
1906
+ if residual_add_to_attention_path is None:
1907
+ residual_add_to_attention_parent_index = 1
1908
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1909
+ residual_add_node, ["Add", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0]
1910
+ )
1911
+
1912
+ # SkipLayerNormalization path
1913
+ else:
1914
+ residual_add_to_attention_parent_index = 0
1915
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1916
+ residual_add_node, ["Cast", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0]
1917
+ )
1918
+
1919
+ # Try other parent index of the residual Add node
1920
+ if residual_add_to_attention_path is None:
1921
+ residual_add_to_attention_parent_index = 1
1922
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1923
+ residual_add_node, ["Cast", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0]
1924
+ )
1925
+
1926
+ # Try without the Casts before and after the MatMuls
1927
+ if residual_add_to_attention_path is None:
1928
+ residual_add_to_attention_parent_index = 0
1929
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1930
+ residual_add_node, ["MatMul", "Attention"], [residual_add_to_attention_parent_index, 0]
1931
+ )
1932
+
1933
+ # Try without the Casts before and after the MatMuls and other parent index of the residual Add node
1934
+ if residual_add_to_attention_path is None:
1935
+ residual_add_to_attention_parent_index = 1
1936
+ residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
1937
+ residual_add_node, ["MatMul", "Attention"], [residual_add_to_attention_parent_index, 0]
1938
+ )
1939
+
1940
+ # TODO(hasesh): Are there more permutations to try before returning?
1941
+ if residual_add_to_attention_path is None:
1942
+ return False
1943
+
1944
+ residual_add_to_add_parent_index = 0 if residual_add_to_attention_parent_index == 1 else 1
1945
+
1946
+ # Regular LayerNormalization path
1947
+ if not is_skiplayernorm_path:
1948
+ add_before_residual_add = gpt2_init_decoder_model.match_parent(
1949
+ residual_add_node, "Add", residual_add_to_add_parent_index
1950
+ )
1951
+
1952
+ # SkipLayerNormalization path
1953
+ else:
1954
+ add_before_residual_add = gpt2_init_decoder_model.match_parent(
1955
+ residual_add_node, "SkipLayerNormalization", residual_add_to_add_parent_index
1956
+ )
1957
+
1958
+ if add_before_residual_add is None:
1959
+ return False
1960
+
1961
+ attention = residual_add_to_attention_path[-1]
1962
+ matmul_after_attention = residual_add_to_attention_path[-2]
1963
+
1964
+ slice_starts = onnx.helper.make_tensor(
1965
+ name="SliceLastTokenStarts",
1966
+ data_type=TensorProto.INT32,
1967
+ dims=[1],
1968
+ vals=[-1],
1969
+ )
1970
+
1971
+ slice_ends = onnx.helper.make_tensor(
1972
+ name="SliceLastTokenEnds",
1973
+ data_type=TensorProto.INT32,
1974
+ dims=[1],
1975
+ vals=[-2],
1976
+ )
1977
+
1978
+ slice_axes = onnx.helper.make_tensor(
1979
+ name="SliceLastTokenAxes",
1980
+ data_type=TensorProto.INT32,
1981
+ dims=[1],
1982
+ vals=[1],
1983
+ )
1984
+
1985
+ slice_steps = onnx.helper.make_tensor(
1986
+ name="SliceLastTokenSteps",
1987
+ data_type=TensorProto.INT32,
1988
+ dims=[1],
1989
+ vals=[-1],
1990
+ )
1991
+
1992
+ gpt2_init_decoder_model.add_initializer(slice_starts)
1993
+ gpt2_init_decoder_model.add_initializer(slice_ends)
1994
+ gpt2_init_decoder_model.add_initializer(slice_axes)
1995
+ gpt2_init_decoder_model.add_initializer(slice_steps)
1996
+
1997
+ # Add Slice node to the graph such that it consumes the output of Attention
1998
+ slice_0_output_name = "edge_modified_" + attention.output[0]
1999
+ slice_node_0 = onnx.helper.make_node(
2000
+ "Slice",
2001
+ inputs=[
2002
+ attention.output[0],
2003
+ "SliceLastTokenStarts",
2004
+ "SliceLastTokenEnds",
2005
+ "SliceLastTokenAxes",
2006
+ "SliceLastTokenSteps",
2007
+ ],
2008
+ outputs=[slice_0_output_name],
2009
+ name=gpt2_init_decoder_model.create_node_name("Slice", "GatherLastToken_0_"),
2010
+ )
2011
+
2012
+ # Add Slice node to the graph such that it consumes the output of Add before the residual Add
2013
+ # If the 'Add' output is produced by a SkipLayerNormalization node, then adjust its output
2014
+ # index appropriately
2015
+ add_before_residual_add_output = (
2016
+ add_before_residual_add.output[0] if not is_skiplayernorm_path else add_before_residual_add.output[3]
2017
+ )
2018
+
2019
+ slice_1_output_name = "edge_modified_" + add_before_residual_add.output[0]
2020
+ slice_node_1 = onnx.helper.make_node(
2021
+ "Slice",
2022
+ inputs=[
2023
+ add_before_residual_add_output,
2024
+ "SliceLastTokenStarts",
2025
+ "SliceLastTokenEnds",
2026
+ "SliceLastTokenAxes",
2027
+ "SliceLastTokenSteps",
2028
+ ],
2029
+ outputs=[slice_1_output_name],
2030
+ name=gpt2_init_decoder_model.create_node_name("Slice", "GatherLastToken_1_"),
2031
+ )
2032
+
2033
+ # Add the 2 Slice nodes
2034
+ gpt2_init_decoder_model.add_node(slice_node_0)
2035
+ gpt2_init_decoder_model.add_node(slice_node_1)
2036
+
2037
+ # Adjust the input(s) to the nodes consuming the outputs of the added Slice nodes
2038
+ gpt2_init_decoder_model.replace_node_input(matmul_after_attention, attention.output[0], slice_0_output_name)
2039
+ gpt2_init_decoder_model.replace_node_input(residual_add_node, add_before_residual_add_output, slice_1_output_name)
2040
+
2041
+ # Topologically sort the updated graph
2042
+ gpt2_init_decoder_model.topological_sort()
2043
+
2044
+ # Save the init decoder model
2045
+ OnnxModel.save(init_decoder_model_proto, init_decoder_onnx_path, save_as_external_data=use_external_data_format)
2046
+ return True
2047
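# --- Illustrative sketch (not part of the packaged file): splitting a GPT-2
# decoder into an initial-run decoder plus a sequence-length-1 decoder, mirroring
# how convert_generation_model below wires the two subgraphs. Paths are placeholders.
def _example_split_gpt2_decoder():
    decoder_path = "gpt2_past_fp16.onnx"             # hypothetical path
    init_decoder_path = "gpt2_init_past_fp16.onnx"   # hypothetical path
    if generate_gpt2_init_decoder(decoder_path, init_decoder_path, use_external_data_format=True):
        # The remaining (non-initial) decoder always sees sequence length 1.
        update_input_shapes_for_gpt2_decoder_model(decoder_path, use_external_data_format=True)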
+
2048
+
2049
+ def make_dim_proto_numeric_t5(model, config):
2050
+ """Make dim_proto numeric.
2051
+
2052
+ Args:
2053
+ model: T5 encoder and decoder model.
2054
+ config: T5 config.
2055
+ """
2056
+ sequence_length = str(1)
2057
+ num_heads = str(config.num_heads)
2058
+ hidden_size = str(config.d_model)
2059
+ head_size = str(config.d_kv)
2060
+
2061
+ for tensor in model.graph.output:
2062
+ for dim_proto in tensor.type.tensor_type.shape.dim:
2063
+ if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
2064
+ sequence_length,
2065
+ num_heads,
2066
+ hidden_size,
2067
+ head_size,
2068
+ ]:
2069
+ dim_value = int(dim_proto.dim_param)
2070
+ dim_proto.Clear()
2071
+ dim_proto.dim_value = dim_value
2072
+
2073
+ for tensor in model.graph.input:
2074
+ for dim_proto in tensor.type.tensor_type.shape.dim:
2075
+ if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
2076
+ sequence_length,
2077
+ num_heads,
2078
+ hidden_size,
2079
+ head_size,
2080
+ ]:
2081
+ dim_value = int(dim_proto.dim_param)
2082
+ dim_proto.Clear()
2083
+ dim_proto.dim_value = dim_value
2084
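# --- Illustrative sketch (not part of the packaged file): make_dim_proto_numeric_t5
# turns dim_param strings that happen to spell config values (e.g. "512" when
# config.d_model == 512) into numeric dim_value entries. The config object here is a
# stand-in exposing only the fields the helper reads.
def _example_make_dim_proto_numeric():
    import onnx
    from types import SimpleNamespace

    vi = onnx.helper.make_tensor_value_info("logits", onnx.TensorProto.FLOAT, ["batch_size", "1", "512"])
    graph = onnx.helper.make_graph(nodes=[], name="g", inputs=[], outputs=[vi])
    model = onnx.helper.make_model(graph)
    config = SimpleNamespace(num_heads=8, d_model=512, d_kv=64)
    make_dim_proto_numeric_t5(model, config)
    assert shape_of(model.graph.output[0]) == ("batch_size", 1, 512)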
+
2085
+
2086
+ def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH):
2087
+ """Convert model according to command line arguments.
2088
+
2089
+ Args:
2090
+ args (argparse.Namespace): arguments parsed from command line
2091
+ """
2092
+ is_gpt2: bool = args.model_type == "gpt2"
2093
+ is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
2094
+ is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
2095
+ is_sampling: bool = generation_type == GenerationType.SAMPLING
2096
+ past_present_share_buffer: bool = args.past_present_share_buffer
2097
+
2098
+ logger.info(f"**** past_present_share_buffer={past_present_share_buffer}")
2099
+ if len(args.op_block_list) == 1 and args.op_block_list[0] == "auto":
2100
+ if is_gpt2 and args.precision == Precision.FLOAT16:
2101
+ args.op_block_list = ["Add", "LayerNormalization", "SkipLayerNormalization", "FastGelu"]
2102
+ logger.info(f"**** Setting op_block_list to {args.op_block_list}")
2103
+ logger.info("**** use --op_block_list if you want to override the block operator list.")
2104
+ else:
2105
+ args.op_block_list = []
2106
+
2107
+ if is_greedysearch or is_sampling:
2108
+ if not is_gpt2:
2109
+ raise NotImplementedError("Currently only gpt2 with greedy search/sampling is supported")
2110
+ if args.output_sequences_scores:
2111
+ raise NotImplementedError("output_sequences_scores currently is not supported in greedy search/sampling")
2112
+ if args.output_token_scores:
2113
+ raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
2114
+
2115
+ # For BeamSearch, sharing buffers for past and present states is only supported
2116
+ # when using `use_decoder_masked_attention`
2117
+ if past_present_share_buffer and is_beamsearch and not args.use_decoder_masked_attention:
2118
+ raise ValueError(
2119
+ "`use_decoder_masked_attention` MUST be turned on to use `past_present_share_buffer` in case of BeamSearch"
2120
+ )
2121
+
2122
+ # For any kind of sampling, using decoder masked multihead attention is only supported
2123
+ # when using `past_present_share_buffer`
2124
+ if args.use_decoder_masked_attention and not past_present_share_buffer:
2125
+ raise ValueError("`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_attention`")
2126
+
2127
+ # For any kind of sampling, using decoder masked multihead attention is only supported
2128
+ # on GPUs
2129
+ if args.use_decoder_masked_attention and not args.use_gpu:
2130
+ raise ValueError("`use_decoder_masked_attention` option is only supported on GPUs")
2131
+
2132
+ if is_gpt2:
2133
+ if args.decoder_onnx and os.path.exists(args.decoder_onnx):
2134
+ logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
2135
+ else:
2136
+ if not args.decoder_onnx:
2137
+ onnx_filename = "{}_past_{}.onnx".format(
2138
+ args.model_name_or_path, "fp16" if args.precision == Precision.FLOAT16 else "fp32"
2139
+ )
2140
+ args.decoder_onnx = Path(Path(args.output).parent, onnx_filename).as_posix()
2141
+
2142
+ logger.info(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...")
2143
+ gpt2_to_onnx(args)
2144
+ else: # t5 or mt5
2145
+ if args.decoder_onnx and args.encoder_decoder_init_onnx:
2146
+ logger.info(
2147
+ f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}"
2148
+ )
2149
+ else:
2150
+ logger.info(f"Convert model {args.model_name_or_path} to onnx ...")
2151
+ t5_to_onnx(args)
2152
+
2153
+ # We only want to pad the logits MatMul weight in the decoder for fp16 models.
2154
+ # The inherent assumption is that fp16 models run on GPU for which all
2155
+ # dims need to be a multiple of 8 to leverage tensor cores.
2156
+ # NOTE: We currently only support padding the MatMul logits weight for GPT2 GreedySearch/BeamSearch.
2157
+ # This can be expanded to other models/decoding strategies later
2158
+ logits_matmul_weight_padded = False
2159
+ if (
2160
+ not args.disable_pad_vocab_size
2161
+ and args.precision == Precision.FLOAT16
2162
+ and is_gpt2
2163
+ and (is_beamsearch or is_greedysearch or is_sampling)
2164
+ ):
2165
+ logger.info(
2166
+ f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. "
2167
+ "The file will be overwritten."
2168
+ )
2169
+ logits_matmul_weight_padded = pad_weights_of_logits_matmul(args.decoder_onnx, args.use_external_data_format)
2170
+ if not logits_matmul_weight_padded:
2171
+ logger.warning(
2172
+ "Tried and failed to pad logits MatMul weights. Performance may be sub-optimal for this MatMul"
2173
+ )
2174
+
2175
+ gpt2_init_decoder_generated = False
2176
+ gpt2_init_decoder_onnx_path = None
2177
+ if (
2178
+ not args.disable_separate_gpt2_decoder_for_init_run
2179
+ and is_gpt2
2180
+ and (is_beamsearch or is_greedysearch or is_sampling)
2181
+ ):
2182
+ logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
2183
+
2184
+ gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format(
2185
+ "fp16" if args.precision == Precision.FLOAT16 else "fp32"
2186
+ )
2187
+
2188
+ gpt2_init_decoder_onnx_path = Path(Path(args.output).parent, gpt2_init_decoder_onnx_filename).as_posix()
2189
+
2190
+ gpt2_init_decoder_generated = generate_gpt2_init_decoder(
2191
+ args.decoder_onnx, gpt2_init_decoder_onnx_path, args.use_external_data_format
2192
+ )
2193
+
2194
+ if not gpt2_init_decoder_generated:
2195
+ logger.warning(
2196
+ "Tried and failed to generate the init decoder GPT2 model. "
2197
+ "Performance may be sub-optimal for the initial decoding run"
2198
+ )
2199
+
2200
+ # Update the graph input shapes for the non-initial decoder model to account
2201
+ # for the fact that the sequence length will always be 1
2202
+ if gpt2_init_decoder_generated and not update_input_shapes_for_gpt2_decoder_model(
2203
+ args.decoder_onnx, args.use_external_data_format
2204
+ ):
2205
+ # Can't proceed further - better to raise an exception
2206
+ raise ValueError("Could not update the input shapes for the non-initial decoder subgraph.")
2207
+
2208
+ # If the user explicitly requests running shape inference or if we padded/mutated
2209
+ # weight(s)/input shape(s) in the decoder, we want to run shape inference to capture the new
2210
+ # shapes
2211
+ if logits_matmul_weight_padded or args.run_shape_inference or gpt2_init_decoder_generated:
2212
+ logger.info(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
2213
+ shape_inference(args.decoder_onnx, args.use_external_data_format)
2214
+ if gpt2_init_decoder_generated:
2215
+ logger.info(f"Run symbolic shape inference on {gpt2_init_decoder_onnx_path}. The file will be overwritten.")
2216
+ shape_inference(gpt2_init_decoder_onnx_path, args.use_external_data_format)
2217
+
2218
+ if is_gpt2:
2219
+ config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2220
+ elif args.model_type == "t5":
2221
+ config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2222
+ else:
2223
+ config = MT5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2224
+
2225
+ if args.verbose:
2226
+ logger.info(f"Config={config}")
2227
+
2228
+ eos_token_id = config.eos_token_id
2229
+ pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id
2230
+ vocab_size = config.vocab_size
2231
+
2232
+ # if vocab_size is given in parameters use that.
2233
+ if args.vocab_size != -1:
2234
+ vocab_size = args.vocab_size
2235
+
2236
+ if args.eos_token_id != -1:
2237
+ eos_token_id = args.eos_token_id
2238
+ if args.pad_token_id != -1:
2239
+ pad_token_id = args.pad_token_id
2240
+
2241
+ decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True)
2242
+ decoder_model.graph.name = f"{args.model_type} decoder"
2243
+
2244
+ gpt2_init_decoder_model = None
2245
+ if args.model_type == "gpt2":
2246
+ verify_gpt2_subgraph(decoder_model.graph, args.precision)
2247
+
2248
+ # If we generated the init decoder model, verify that as well
2249
+ if gpt2_init_decoder_generated:
2250
+ gpt2_init_decoder_model = onnx.load_model(gpt2_init_decoder_onnx_path, load_external_data=True)
2251
+ gpt2_init_decoder_model.graph.name = f"{args.model_type} init decoder"
2252
+ verify_gpt2_subgraph(gpt2_init_decoder_model.graph, args.precision)
2253
+ else:
2254
+ verify_t5_decoder_subgraph(decoder_model.graph, args.precision)
2255
+
2256
+ inputs = None
2257
+ if is_beamsearch:
2258
+ inputs = [
2259
+ "input_ids",
2260
+ "max_length",
2261
+ "min_length",
2262
+ "num_beams",
2263
+ "num_return_sequences",
2264
+ "length_penalty",
2265
+ "repetition_penalty",
2266
+ ]
2267
+ elif is_greedysearch or is_sampling:
2268
+ inputs = [
2269
+ "input_ids",
2270
+ "max_length",
2271
+ "min_length",
2272
+ "repetition_penalty",
2273
+ ]
2274
+
2275
+ if args.vocab_mask:
2276
+ inputs.append("vocab_mask")
2277
+ else:
2278
+ inputs.append("")
2279
+
2280
+ if args.prefix_vocab_mask:
2281
+ inputs.append("prefix_vocab_mask")
2282
+ else:
2283
+ inputs.append("")
2284
+
2285
+ if args.custom_attention_mask:
2286
+ inputs.append("attention_mask")
2287
+ else:
2288
+ inputs.append("")
2289
+
2290
+ if is_sampling:
2291
+ if args.custom and args.presence_mask:
2292
+ inputs.append("presence_mask")
2293
+ else:
2294
+ inputs.append("")
2295
+
2296
+ if args.seed:
2297
+ inputs.append("seed")
2298
+
2299
+ outputs = ["sequences"]
2300
+ if args.output_sequences_scores:
2301
+ outputs.append("sequences_scores")
2302
+
2303
+ if args.output_token_scores:
2304
+ assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
2305
+ outputs.append("scores")
2306
+
2307
+ node = None
2308
+ if is_beamsearch:
2309
+ node = onnx.helper.make_node(
2310
+ "BeamSearch",
2311
+ inputs=inputs,
2312
+ outputs=outputs,
2313
+ name=f"BeamSearch_{args.model_type}",
2314
+ )
2315
+ elif is_greedysearch:
2316
+ node = onnx.helper.make_node(
2317
+ "GreedySearch",
2318
+ inputs=inputs,
2319
+ outputs=outputs,
2320
+ name=f"GreedySearch_{args.model_type}",
2321
+ )
2322
+ elif is_sampling:
2323
+ node = onnx.helper.make_node(
2324
+ "Sampling",
2325
+ inputs=inputs,
2326
+ outputs=outputs,
2327
+ name=f"Sampling_{args.model_type}",
2328
+ )
2329
+
2330
+ node.domain = "com.microsoft"
2331
+
2332
+ attr_to_extend = None
2333
+ if is_beamsearch:
2334
+ attr_to_extend = [
2335
+ onnx.helper.make_attribute("eos_token_id", eos_token_id),
2336
+ onnx.helper.make_attribute("pad_token_id", pad_token_id),
2337
+ onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
2338
+ onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
2339
+ onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
2340
+ ]
2341
+ elif is_greedysearch:
2342
+ attr_to_extend = [
2343
+ onnx.helper.make_attribute("eos_token_id", eos_token_id),
2344
+ onnx.helper.make_attribute("pad_token_id", pad_token_id),
2345
+ onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
2346
+ onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
2347
+ ]
2348
+ elif is_sampling:
2349
+ attr_to_extend = [
2350
+ onnx.helper.make_attribute("eos_token_id", eos_token_id),
2351
+ onnx.helper.make_attribute("pad_token_id", pad_token_id),
2352
+ onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
2353
+ onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
2354
+ onnx.helper.make_attribute("temperature", args.temperature),
2355
+ onnx.helper.make_attribute("top_p", args.top_p),
2356
+ onnx.helper.make_attribute("filter_value", args.filter_value),
2357
+ onnx.helper.make_attribute("min_tokens_to_keep", args.min_tokens_to_keep),
2358
+ onnx.helper.make_attribute("custom", args.custom),
2359
+ onnx.helper.make_attribute("presence_penalty", args.presence_penalty),
2360
+ ]
2361
+
2362
+ # Explicitly pass in the vocab size via an attribute
2363
+ if logits_matmul_weight_padded:
2364
+ attr_to_extend.extend([onnx.helper.make_attribute("vocab_size", vocab_size)])
2365
+
2366
+ node.attribute.extend(attr_to_extend)
2367
+
2368
+ initializers = []
2369
+
2370
+ if args.model_type in ["t5", "mt5"]:
2371
+ if args.run_shape_inference:
2372
+ logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
2373
+ shape_inference(args.encoder_decoder_init_onnx, args.use_external_data_format)
2374
+ encoder_model = onnx.load_model(args.encoder_decoder_init_onnx, load_external_data=True)
2375
+ encoder_model.graph.name = f"{args.model_type} encoder and decoder init"
2376
+ verify_t5_encoder_decoder_init_subgraph(encoder_model.graph, args.precision)
2377
+
2378
+ make_dim_proto_numeric_t5(encoder_model, config)
2379
+ make_dim_proto_numeric_t5(decoder_model, config)
2380
+
2381
+ # Update decoder subgraph in preparation to use past present share buffer
2382
+ if past_present_share_buffer:
2383
+ if not args.use_decoder_masked_attention:
2384
+ raise ValueError("past_present_share_buffer is only supported with use_decoder_masked_attention")
2385
+
2386
+ logger.info(
2387
+ "*****update t5 decoder subgraph to share past/present buffer and use decoder_masked_multihead_attention*****"
2388
+ )
2389
+ if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
2390
+ logger.info("*****update t5 decoder subgraph successfully!!!*****")
2391
+ else:
2392
+ logger.info("*****DecoderMaskedMultiHeadAttention is not applied to T5 decoder*****")
2393
+
2394
+ if pack_qkv_for_decoder_masked_mha(decoder_model):
2395
+ logger.info("*****pack qkv for decoder masked mha successfully!!!*****")
2396
+ else:
2397
+ logger.info("*****pack qkv for decoder masked mha failed!!!*****")
2398
+
2399
+ if not args.disable_shared_initializers:
2400
+ # Moving initializers shared by the encoder and decoder subgraphs to the main graph could reduce memory usage in inference.
2401
+ initializers = get_shared_initializers(encoder_model, decoder_model)
2402
+ logger.info(
2403
+ f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in encoder and decoder subgraphs are moved to the main graph"
2404
+ )
2405
+
2406
+ # TODO(tianleiwu): investigate the following which causes error in inference
2407
+ # Moving initializers from the subgraph to the main graph could reduce memory usage in inference.
2408
+ # moved_initializers = move_initializers(encoder_model.graph)
2409
+ # logger.info(
2410
+ # f"{len(moved_initializers)} initializers ({[i.name for i in moved_initializers]}) from the encoder are moved to the main graph"
2411
+ # )
2412
+ # initializers.extend(moved_initializers)
2413
+
2414
+ node.attribute.extend(
2415
+ [
2416
+ onnx.helper.make_attribute("encoder", encoder_model.graph),
2417
+ onnx.helper.make_attribute("decoder", decoder_model.graph),
2418
+ onnx.helper.make_attribute(
2419
+ "decoder_start_token_id",
2420
+ config.decoder_start_token_id if len(encoder_model.graph.input) == 3 else -1,
2421
+ ),
2422
+ ]
2423
+ )
2424
+ else:
2425
+ if gpt2_init_decoder_generated:
2426
+ # Move shared initializers (shared between init decoder and decoder models) to the main
2427
+ # graph and remove them from these models
2428
+ if not args.disable_shared_initializers:
2429
+ # Moving initializers shared by the decoder and init decoder subgraphs to the main graph could reduce memory usage in inference.
2430
+ initializers = get_shared_initializers(gpt2_init_decoder_model, decoder_model)
2431
+ logger.info(
2432
+ f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in decoder and init decoder subgraphs are moved to the main graph"
2433
+ )
2434
+
2435
+ # Update init decoder subgraph in preparation to use past present share buffer
2436
+ if past_present_share_buffer:
2437
+ logger.info("*****update init decoder subgraph to make past and present share buffer******************")
2438
+ update_decoder_subgraph_past_present_share_buffer(gpt2_init_decoder_model.graph)
2439
+
2440
+ # Update init decoder subgraph in preparation to use DecoderMaskedSelfAttention
2441
+ # NOTE: Even if DecoderMaskedSelfAttention will not be used in the init decoder subgraph,
2442
+ # keeping the init decoder and decoder subgraphs identical in terms of subgraph inputs
2443
+ # makes the runtime changes cleaner.
2444
+ if args.use_decoder_masked_attention and not update_decoder_subgraph_use_decoder_masked_attention(
2445
+ gpt2_init_decoder_model.graph, is_beamsearch, False
2446
+ ):
2447
+ raise ValueError("Could not update the init decoder subgraph to use DecoderMaskedSelfAttention")
2448
+
2449
+ node.attribute.append(onnx.helper.make_attribute("init_decoder", gpt2_init_decoder_model.graph))
2450
+ else:
2451
+ # Moving initializers from the subgraph to the main graph could reduce memory usage in inference.
2452
+ initializers = move_initializers(decoder_model.graph)
2453
+ logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph")
2454
+
2455
+ # Update decoder subgraph in preparation to use past present share buffer
2456
+ if past_present_share_buffer:
2457
+ logger.info("*****update decoder subgraph to make past and present share buffer******************")
2458
+ update_decoder_subgraph_past_present_share_buffer(decoder_model.graph)
2459
+
2460
+ # Update decoder subgraph in preparation to use DecoderMaskedSelfAttention
2461
+ if args.use_decoder_masked_attention and not update_decoder_subgraph_use_decoder_masked_attention(
2462
+ decoder_model.graph, is_beamsearch, True
2463
+ ):
2464
+ raise ValueError("Could not update the decoder subgraph to use DecoderMaskedSelfAttention")
2465
+
2466
+ node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph))
2467
+
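+ # Note (added for clarity): the subgraphs prepared above are attached to the generation
+ # node as GRAPH-typed attributes ("encoder"/"decoder" for T5, "decoder" and optionally
+ # "init_decoder" for GPT-2), so the exported model embeds them inside that single node.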
2468
+ # graph inputs
2469
+ input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
2470
+ max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
2471
+ min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
2472
+ num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
2473
+ num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
2474
+ length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
2475
+ repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
2476
+
2477
+ graph_inputs = None
2478
+ if is_beamsearch:
2479
+ graph_inputs = [
2480
+ input_ids,
2481
+ max_length,
2482
+ min_length,
2483
+ num_beams,
2484
+ num_return_sequences,
2485
+ length_penalty,
2486
+ repetition_penalty,
2487
+ ]
2488
+ elif is_greedysearch or is_sampling:
2489
+ graph_inputs = [
2490
+ input_ids,
2491
+ max_length,
2492
+ min_length,
2493
+ repetition_penalty,
2494
+ ]
2495
+
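+ # The remaining graph inputs are optional and are only declared when the corresponding
+ # command-line switches are set (vocab_mask, prefix_vocab_mask, custom attention_mask,
+ # presence_mask, and seed for sampling).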
2496
+ if args.vocab_mask:
2497
+ vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
2498
+ graph_inputs.append(vocab_mask)
2499
+
2500
+ if args.prefix_vocab_mask:
2501
+ prefix_vocab_mask = onnx.helper.make_tensor_value_info(
2502
+ "prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size]
2503
+ )
2504
+ graph_inputs.append(prefix_vocab_mask)
2505
+
2506
+ if args.custom_attention_mask:
2507
+ attention_mask = onnx.helper.make_tensor_value_info(
2508
+ "attention_mask", TensorProto.INT32, ["batch_size", "sequence_length"]
2509
+ )
2510
+ graph_inputs.append(attention_mask)
2511
+
2512
+ if args.custom and args.presence_mask:
2513
+ presence_mask = onnx.helper.make_tensor_value_info(
2514
+ "presence_mask", TensorProto.INT32, ["batch_size", vocab_size]
2515
+ )
2516
+ graph_inputs.append(presence_mask)
2517
+
2518
+ if is_sampling and args.seed:
2519
+ seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
2520
+ graph_inputs.append(seed)
2521
+
2522
+ # graph outputs
2523
+ sequences = None
2524
+ if is_beamsearch:
2525
+ sequences = onnx.helper.make_tensor_value_info(
2526
+ "sequences",
2527
+ TensorProto.INT32,
2528
+ ["batch_size", "num_return_sequences", "max_length"],
2529
+ )
2530
+ elif is_greedysearch or is_sampling:
2531
+ sequences = onnx.helper.make_tensor_value_info(
2532
+ "sequences",
2533
+ TensorProto.INT32,
2534
+ ["batch_size", "max_length"],
2535
+ )
2536
+
2537
+ graph_outputs = [sequences]
2538
+
2539
+ if args.output_sequences_scores:
2540
+ sequences_scores = onnx.helper.make_tensor_value_info(
2541
+ "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"]
2542
+ )
2543
+ graph_outputs.append(sequences_scores)
2544
+
2545
+ if args.output_token_scores:
2546
+ scores = onnx.helper.make_tensor_value_info(
2547
+ "scores",
2548
+ TensorProto.FLOAT,
2549
+ ["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
2550
+ )
2551
+ graph_outputs.append(scores)
2552
+
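+ # The top-level graph therefore contains just the single fused generation node, the
+ # inputs and outputs declared above, and any initializers hoisted out of the subgraphs.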
2553
+ new_graph = onnx.helper.make_graph(
2554
+ [node],
2555
+ f"{args.model_type} beam search" if not is_greedysearch else f"{args.model_type} greedy search",
2556
+ graph_inputs,
2557
+ graph_outputs,
2558
+ initializers,
2559
+ )
2560
+
2561
+ # Create the model
2562
+ new_model = onnx.helper.make_model(
2563
+ new_graph,
2564
+ producer_name="onnxruntime.transformers",
2565
+ opset_imports=decoder_model.opset_import,
2566
+ )
2567
+
2568
+ # TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory.
2569
+ if args.use_external_data_format:
2570
+ from packaging import version
2571
+
2572
+ if version.parse(onnx.__version__) < version.parse("1.12.0"):
2573
+ logger.warning("Require onnx >= 1.12 to save large (>2GB) model!")
2574
+
2575
+ OnnxModel.save(
2576
+ new_model,
2577
+ args.output,
2578
+ save_as_external_data=True,
2579
+ all_tensors_to_one_file=True,
2580
+ )
2581
+ else:
2582
+ onnx.save(new_model, args.output)
2583
+ logger.info(f"model save to {args.output}")
2584
+
2585
+
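+ # A minimal sketch (not part of the original script) for sanity-checking an exported
+ # model: confirm the single com.microsoft generation node and list the graph I/O.
+ # `output_path` is a hypothetical argument pointing at the file written above.
+ def _inspect_generation_model(output_path: str) -> None:
+     model = onnx.load(output_path)
+     node = model.graph.node[0]
+     assert node.domain == "com.microsoft"
+     assert node.op_type in ("BeamSearch", "GreedySearch", "Sampling")
+     print("inputs:", [i.name for i in model.graph.input])
+     print("outputs:", [o.name for o in model.graph.output])
+     print("attributes:", [a.name for a in node.attribute])
+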
2586
+ def test_torch_performance(
2587
+ args: argparse.Namespace,
2588
+ model: Union[GPT2LMHeadModel, T5ForConditionalGeneration],
2589
+ input_ids: torch.Tensor,
2590
+ attention_mask: torch.Tensor,
2591
+ eos_token_id: int,
2592
+ pad_token_id: int,
2593
+ bad_words_ids: List[List[int]],
2594
+ ) -> Dict[str, Any]:
2595
+ """Test PyTorch performance of text generation.
2596
+
2597
+ Args:
2598
+ args (argparse.Namespace): arguments parsed from command line
2599
+ model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model
2600
+ input_ids (torch.Tensor): input_ids
2601
+ attention_mask (torch.Tensor): Attention mask
2602
+ eos_token_id (int): EOS token ID
2603
+ pad_token_id (int): Padding token ID
2604
+ bad_words_ids (List[List[int]]): Word IDs that shall not be generated.
2605
+
2606
+ Raises:
2607
+ RuntimeError: PyTorch with CUDA is not available for --use_gpu
2608
+
2609
+ Returns:
2610
+ Dict[str, Any]: A dictionary mapping metric names to values; each value can be an integer or a string.
2611
+ """
2612
+ if args.use_gpu and not torch.cuda.is_available():
2613
+ raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.")
2614
+
2615
+ if args.precision == Precision.FLOAT16:
2616
+ model.half()
2617
+
2618
+ device = torch.device("cuda:0" if args.use_gpu else "cpu")
2619
+ model.to(device)
2620
+
2621
+ torch.set_grad_enabled(False)
2622
+ input_ids = input_ids.to(device)
2623
+ attention_mask = attention_mask.to(device)
2624
+
2625
+ torch_latency = []
2626
+ for _ in range(args.total_runs):
2627
+ start = time.time()
2628
+ _ = model.generate(
2629
+ input_ids=input_ids,
2630
+ attention_mask=attention_mask,
2631
+ max_length=args.max_length,
2632
+ min_length=args.min_length,
2633
+ num_beams=args.num_beams,
2634
+ early_stopping=args.early_stopping,
2635
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
2636
+ eos_token_id=eos_token_id,
2637
+ pad_token_id=pad_token_id,
2638
+ num_return_sequences=args.num_return_sequences,
2639
+ length_penalty=args.length_penalty,
2640
+ repetition_penalty=args.repetition_penalty,
2641
+ bad_words_ids=bad_words_ids if bad_words_ids else None,
2642
+ return_dict_in_generate=True,
2643
+ output_scores=args.output_sequences_scores or args.output_token_scores,
2644
+ )
2645
+ torch_latency.append(time.time() - start)
2646
+ batch_size = input_ids.shape[0]
2647
+ from benchmark_helper import get_latency_result
2648
+
2649
+ return get_latency_result(torch_latency, batch_size)
2650
+
2651
+
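+ # A small sketch (added for illustration, assuming nothing beyond the standard library)
+ # of the timing pattern used above and in the ORT tests below: wall-clock time around
+ # each of `total_runs` invocations of a zero-argument callable.
+ def _measure_latency(run_once, total_runs: int) -> List[float]:
+     latency = []
+     for _ in range(total_runs):
+         start = time.time()
+         run_once()
+         latency.append(time.time() - start)
+     return latency
+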
2652
+ def create_attention_mask(input_ids, pad_token_id):
2653
+ attention_mask = np.ones(input_ids.shape, dtype=np.int32)
2654
+ for i in range(input_ids.shape[0]):
2655
+ abs_pos = 0
2656
+ for j in range(input_ids.shape[1]):
2657
+ if input_ids[i][j] == pad_token_id and abs_pos == 0:
2658
+ attention_mask[i][j] = 0
2659
+ else:
2660
+ abs_pos += 1
2661
+ return attention_mask
2662
+
2663
+
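+ # Example (illustrative): with left padding and pad_token_id = 0, leading pads are
+ # masked out and every position from the first real token onward stays attended:
+ #     create_attention_mask(np.array([[0, 0, 5, 6]]), pad_token_id=0)
+ #     -> array([[0, 0, 1, 1]], dtype=int32)
+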
2664
+ def test_gpt_model(args: argparse.Namespace, sentences: Optional[List[str]] = None, is_greedy: bool = False):
2665
+ """Test GPT-2 model
2666
+
2667
+ Args:
2668
+ args (argparse.Namespace): arguments parsed from command line
2669
+ sentences (Optional[List[str]], optional): input text. Defaults to None.
2670
+
2671
+ Returns:
2672
+ Union[Dict[str, Any], None]: A dictionary mapping metric names to values; each value can be an integer or a string.
2673
+ """
2674
+ assert args.model_type == "gpt2"
2675
+
2676
+ tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2677
+ tokenizer.padding_side = "left"
2678
+ tokenizer.pad_token = tokenizer.eos_token
2679
+
2680
+ model = GPT2LMHeadModel.from_pretrained(
2681
+ args.model_name_or_path,
2682
+ cache_dir=args.cache_dir,
2683
+ pad_token_id=tokenizer.eos_token_id,
2684
+ )
2685
+
2686
+ # Use different length sentences to test batching
2687
+ if sentences is None:
2688
+ sentences = [
2689
+ "The product is released",
2690
+ "I enjoy walking in the park",
2691
+ "Test best way to invest",
2692
+ ]
2693
+
2694
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True)
2695
+ input_ids = inputs["input_ids"]
2696
+ attention_mask = inputs["attention_mask"]
2697
+
2698
+ bad_words = "walk in park"
2699
+ bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True)
2700
+ bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
2701
+ if args.vocab_mask:
2702
+ logger.debug("bad_words_ids", bad_words_ids) # noqa: PLE1205
2703
+ else:
2704
+ bad_words_ids = []
2705
+
2706
+ config = model.config
2707
+ eos_token_id = config.eos_token_id
2708
+ pad_token_id = config.eos_token_id
2709
+ vocab_size = config.vocab_size
2710
+
2711
+ torch_decoded_sequences = []
2712
+ beam_outputs = None
2713
+ if not args.disable_parity:
2714
+ print("-" * 50)
2715
+ print("Test PyTorch model and beam search with huggingface transformers...")
2716
+ beam_outputs = model.generate(
2717
+ input_ids=input_ids,
2718
+ attention_mask=attention_mask,
2719
+ max_length=args.max_length,
2720
+ min_length=args.min_length,
2721
+ num_beams=args.num_beams,
2722
+ early_stopping=args.early_stopping,
2723
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
2724
+ eos_token_id=eos_token_id,
2725
+ pad_token_id=pad_token_id,
2726
+ num_return_sequences=args.num_return_sequences,
2727
+ length_penalty=args.length_penalty,
2728
+ repetition_penalty=args.repetition_penalty,
2729
+ bad_words_ids=bad_words_ids if bad_words_ids else None,
2730
+ return_dict_in_generate=True,
2731
+ output_scores=args.output_sequences_scores or args.output_token_scores,
2732
+ )
2733
+ print("input_ids", input_ids)
2734
+ print("huggingface transformers outputs:")
2735
+ print("sequences", beam_outputs.sequences)
2736
+ if args.output_sequences_scores:
2737
+ print("sequences_scores", beam_outputs.sequences_scores)
2738
+ if args.output_token_scores:
2739
+ print("scores", beam_outputs.scores)
2740
+ for i, sequence in enumerate(beam_outputs.sequences):
2741
+ decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
2742
+ torch_decoded_sequences.append(decoded_sequence)
2743
+ print(f"{i}: {decoded_sequence}")
2744
+
2745
+ print("-" * 50)
2746
+ print("Testing beam search with onnxruntime...")
2747
+
2748
+ if is_greedy:
2749
+ inputs = {
2750
+ "input_ids": input_ids.cpu().numpy().astype(np.int32),
2751
+ "max_length": np.array([args.max_length], dtype=np.int32),
2752
+ "min_length": np.array([args.min_length], dtype=np.int32),
2753
+ "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
2754
+ }
2755
+ else:
2756
+ inputs = {
2757
+ "input_ids": input_ids.cpu().numpy().astype(np.int32),
2758
+ "max_length": np.array([args.max_length], dtype=np.int32),
2759
+ "min_length": np.array([args.min_length], dtype=np.int32),
2760
+ "num_beams": np.array([args.num_beams], dtype=np.int32),
2761
+ "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
2762
+ "length_penalty": np.array([args.length_penalty], dtype=np.float32),
2763
+ "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
2764
+ }
2765
+
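+ # Note (added for clarity): the scalar generation controls above (max_length, min_length,
+ # num_beams, and so on) are passed as 1-element int32 or float32 arrays, matching the graph
+ # inputs declared during conversion. vocab_mask below is an allow list over the vocabulary:
+ # 1 keeps a token as a candidate and 0 removes it, so the collected bad-word ids are zeroed out.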
2766
+ if args.vocab_mask:
2767
+ vocab_mask = np.ones((vocab_size), dtype=np.int32)
2768
+ if args.vocab_mask:
2769
+ for bad_word_id in bad_words_ids:
2770
+ vocab_mask[bad_word_id] = 0
2771
+ inputs["vocab_mask"] = vocab_mask
2772
+
2773
+ if args.custom_attention_mask:
2774
+ inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id)
2775
+
2776
+ batch_size = input_ids.shape[0]
2777
+ if args.prefix_vocab_mask:
2778
+ logger.info("Use prefix vocab mask with all ones in ORT, but no corresponding setting for Torch model.")
2779
+ prefix_vocab_mask = np.ones((batch_size, vocab_size), dtype=np.int32)
2780
+ inputs["prefix_vocab_mask"] = prefix_vocab_mask
2781
+
2782
+ if args.save_test_data:
2783
+ test_data_dir = Path(args.output).parent.as_posix()
2784
+ logger.debug("test_data_dir", test_data_dir) # noqa: PLE1205
2785
+ from bert_test_data import output_test_data
2786
+
2787
+ logger.info(f"Saving test_data to {test_data_dir}/test_data_set_* ...")
2788
+
2789
+ all_inputs = [inputs]
2790
+ for i, inputs in enumerate(all_inputs):
2791
+ dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
2792
+ output_test_data(dir, inputs)
2793
+
2794
+ logger.debug("ORT inputs", inputs) # noqa: PLE1205
2795
+
2796
+ if args.disable_perf_test:
2797
+ return
2798
+
2799
+ logger.debug("Creating ort session......")
2800
+ ort_session = create_ort_session(args.output, args.use_gpu, args.use_sln_strict_mode)
2801
+
2802
+ logger.debug("Run ort session......")
2803
+ result = ort_session.run(None, inputs)
2804
+
2805
+ # Test performance
2806
+ latency = []
2807
+ for _ in range(args.total_runs):
2808
+ start = time.time()
2809
+ _ = ort_session.run(None, inputs)
2810
+ latency.append(time.time() - start)
2811
+
2812
+ from benchmark_helper import get_latency_result
2813
+
2814
+ batch_size = input_ids.shape[0]
2815
+ output = get_latency_result(latency, batch_size)
2816
+
2817
+ print("ORT outputs:")
2818
+ sequences = result[0]
2819
+ print("sequences", sequences)
2820
+ if args.output_sequences_scores:
2821
+ print("sequences_scores", result[1])
2822
+ if args.output_token_scores:
2823
+ print("scores", result[2])
2824
+
2825
+ if is_greedy:
2826
+ (batch_size, max_length) = sequences.shape
2827
+ ort_decoded_sequences = []
2828
+ for i in range(batch_size):
2829
+ decoded_sequence = tokenizer.decode(sequences[i], skip_special_tokens=True)
2830
+ ort_decoded_sequences.append(decoded_sequence)
2831
+ print(f"batch {i} sequence: {decoded_sequence}")
2832
+ else:
2833
+ (batch_size, num_sequences, max_length) = sequences.shape
2834
+ ort_decoded_sequences = []
2835
+ for i in range(batch_size):
2836
+ for j in range(num_sequences):
2837
+ decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
2838
+ ort_decoded_sequences.append(decoded_sequence)
2839
+ print(f"batch {i} sequence {j}: {decoded_sequence}")
2840
+
2841
+ if beam_outputs:
2842
+ torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
2843
+ ort_sequences = torch.LongTensor(sequences)
2844
+ print("-" * 50)
2845
+ print("Torch Sequences:")
2846
+ print(torch_sequences)
2847
+ print(torch_decoded_sequences)
2848
+ print("-" * 50)
2849
+ print("ORT Sequences:")
2850
+ print(ort_sequences)
2851
+ print(ort_decoded_sequences)
2852
+ print("-" * 50)
2853
+ # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch does not.
2854
+ is_same = torch_decoded_sequences == ort_decoded_sequences
2855
+ print("Torch and ORT result is ", "same" if is_same else "different")
2856
+ output["parity"] = is_same
2857
+
2858
+ if args.torch_performance:
2859
+ torch_latency_output = test_torch_performance(
2860
+ args,
2861
+ model,
2862
+ input_ids,
2863
+ attention_mask,
2864
+ eos_token_id,
2865
+ pad_token_id,
2866
+ bad_words_ids,
2867
+ )
2868
+ print("Torch Latency", torch_latency_output)
2869
+
2870
+ print("ORT", output)
2871
+
2872
+ return output
2873
+
2874
+
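+ # Note (added for clarity): test_t5_model below follows the same flow as test_gpt_model
+ # (optional Hugging Face parity run, ORT inputs, latency loop, decode and compare), but it
+ # always builds the beam-search style inputs and takes pad_token_id from the T5/MT5 config
+ # instead of reusing the EOS token id.
+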
2875
+ def test_t5_model(args: argparse.Namespace, sentences: Optional[List[str]] = None):
2876
+ """Test T5 or MT5 model
2877
+
2878
+ Args:
2879
+ args (argparse.Namespace): arguments parsed from command line
2880
+ sentences (Optional[List[str]], optional): input text. Defaults to None.
2881
+
2882
+ Returns:
2883
+ Union[Dict[str, Any], None]: A dictionary mapping metric names to values; each value can be an integer or a string.
2884
+ """
2885
+ assert args.model_type in ["t5", "mt5"]
2886
+
2887
+ if args.prefix_vocab_mask:
2888
+ logger.debug("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
2889
+ return None
2890
+
2891
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
2892
+ tokenizer.padding_side = "left"
2893
+
2894
+ if args.model_type == "t5":
2895
+ model = T5ForConditionalGeneration.from_pretrained(
2896
+ args.model_name_or_path,
2897
+ cache_dir=args.cache_dir,
2898
+ )
2899
+ else:
2900
+ model = MT5ForConditionalGeneration.from_pretrained(
2901
+ args.model_name_or_path,
2902
+ cache_dir=args.cache_dir,
2903
+ )
2904
+
2905
+ # Use different length sentences to test batching
2906
+ if sentences is None:
2907
+ sentences = [
2908
+ "translate English to French: The product is released",
2909
+ "summarize: research continues to show that pets bring real health benefits to their owners. Having a dog around can lead to lower levels of stress for both adults and kids.",
2910
+ # "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. "
2911
+ # + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.",
2912
+ ]
2913
+
2914
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True)
2915
+ input_ids = inputs["input_ids"]
2916
+ attention_mask = inputs["attention_mask"]
2917
+
2918
+ bad_words = "walk in park"
2919
+ bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS)
2920
+ bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
2921
+ if args.vocab_mask:
2922
+ logger.debug("bad_words_ids", bad_words_ids) # noqa: PLE1205
2923
+ else:
2924
+ bad_words_ids = []
2925
+
2926
+ config = model.config
2927
+ eos_token_id = config.eos_token_id
2928
+ pad_token_id = config.pad_token_id
2929
+ vocab_size = config.vocab_size
2930
+ logger.debug(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}")
2931
+
2932
+ torch_decoded_sequences = []
2933
+ if not args.disable_parity:
2934
+ print("-" * 50)
2935
+ print("Test PyTorch model and beam search with huggingface transformers...")
2936
+ beam_outputs = model.generate(
2937
+ input_ids=input_ids,
2938
+ attention_mask=attention_mask,
2939
+ max_length=args.max_length,
2940
+ min_length=args.min_length,
2941
+ num_beams=args.num_beams,
2942
+ early_stopping=args.early_stopping,
2943
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
2944
+ eos_token_id=eos_token_id,
2945
+ pad_token_id=pad_token_id,
2946
+ num_return_sequences=args.num_return_sequences,
2947
+ length_penalty=args.length_penalty,
2948
+ repetition_penalty=args.repetition_penalty,
2949
+ bad_words_ids=bad_words_ids if bad_words_ids else None,
2950
+ return_dict_in_generate=True,
2951
+ output_scores=args.output_sequences_scores or args.output_token_scores,
2952
+ )
2953
+
2954
+ print("input_ids", input_ids)
2955
+ print("huggingface transformers outputs:")
2956
+ print("sequences", beam_outputs.sequences)
2957
+ if args.output_sequences_scores:
2958
+ print("sequences_scores", beam_outputs.sequences_scores)
2959
+ if args.output_token_scores:
2960
+ print("scores", beam_outputs.scores)
2961
+ for i, sequence in enumerate(beam_outputs.sequences):
2962
+ decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
2963
+ torch_decoded_sequences.append(decoded_sequence)
2964
+ print(f"{i}: {decoded_sequence}")
2965
+
2966
+ print("-" * 50)
2967
+ print("Testing beam search with onnxruntime...")
2968
+
2969
+ vocab_mask = np.ones((vocab_size), dtype=np.int32)
2970
+ if args.vocab_mask:
2971
+ for bad_word_id in bad_words_ids:
2972
+ vocab_mask[bad_word_id] = 0
2973
+
2974
+ inputs = {
2975
+ "input_ids": input_ids.cpu().numpy().astype(np.int32),
2976
+ "max_length": np.array([args.max_length], dtype=np.int32),
2977
+ "min_length": np.array([args.min_length], dtype=np.int32),
2978
+ "num_beams": np.array([args.num_beams], dtype=np.int32),
2979
+ "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
2980
+ "length_penalty": np.array([args.length_penalty], dtype=np.float32),
2981
+ "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
2982
+ }
2983
+
2984
+ if args.vocab_mask:
2985
+ inputs["vocab_mask"] = vocab_mask
2986
+
2987
+ if args.custom_attention_mask:
2988
+ inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id)
2989
+
2990
+ if args.save_test_data:
2991
+ test_data_dir = Path(args.output).parent.as_posix()
2992
+ logger.debug("test_data_dir", test_data_dir) # noqa: PLE1205
2993
+ from bert_test_data import output_test_data
2994
+
2995
+ all_inputs = [inputs]
2996
+ for i, inputs in enumerate(all_inputs):
2997
+ dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
2998
+ output_test_data(dir, inputs)
2999
+
3000
+ logger.debug("ORT inputs", inputs) # noqa: PLE1205
3001
+
3002
+ ort_session = create_ort_session(args.output, args.use_gpu, args.use_sln_strict_mode)
3003
+
3004
+ # Test performance
3005
+ latency = []
3006
+ for _ in range(args.total_runs):
3007
+ start = time.time()
3008
+ result = ort_session.run(None, inputs)
3009
+ latency.append(time.time() - start)
3010
+ batch_size = input_ids.shape[0]
3011
+ from benchmark_helper import get_latency_result
3012
+
3013
+ output = get_latency_result(latency, batch_size)
3014
+
3015
+ print("ORT outputs:")
3016
+ sequences = result[0]
3017
+ print("sequences", sequences)
3018
+ if args.output_sequences_scores:
3019
+ print("sequences_scores", result[1])
3020
+ if args.output_token_scores:
3021
+ print("scores", result[2])
3022
+
3023
+ (batch_size, num_sequences, max_length) = sequences.shape
3024
+ ort_decoded_sequences = []
3025
+ for i in range(batch_size):
3026
+ for j in range(num_sequences):
3027
+ decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
3028
+ ort_decoded_sequences.append(decoded_sequence)
3029
+ print(f"batch {i} sequence {j}: {decoded_sequence}")
3030
+
3031
+ if not args.disable_parity:
3032
+ torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
3033
+ ort_sequences = torch.LongTensor(sequences)
3034
+ print("-" * 50)
3035
+ print("Torch Sequences:")
3036
+ print(torch_sequences)
3037
+ print(torch_decoded_sequences)
3038
+ print("-" * 50)
3039
+ print("ORT Sequences:")
3040
+ print(ort_sequences)
3041
+ print(ort_decoded_sequences)
3042
+ print("-" * 50)
3043
+ # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch does not.
3044
+ is_same = torch_decoded_sequences == ort_decoded_sequences
3045
+ print("Torch and ORT result is ", "same" if is_same else "different")
3046
+ output["parity"] = is_same
3047
+
3048
+ if args.torch_performance:
3049
+ torch_latency_output = test_torch_performance(
3050
+ args,
3051
+ model,
3052
+ input_ids,
3053
+ attention_mask,
3054
+ eos_token_id,
3055
+ pad_token_id,
3056
+ bad_words_ids,
3057
+ )
3058
+ print("Torch Latency", torch_latency_output)
3059
+
3060
+ print("ORT", output)
3061
+ return output
3062
+
3063
+
3064
+ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None):
3065
+ """Main entry function
3066
+
3067
+ Args:
3068
+ argv (Optional[List[str]], optional): command line arguments to parse. Defaults to None.
3069
+ sentences (Optional[List[str]], optional): input text. Defaults to None.
3070
+
3071
+ Raises:
3072
+ ValueError: Path does not exist: --encoder_decoder_init_onnx
3073
+ ValueError: Path does not exist: --decoder_onnx
3074
+ ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5
3075
+
3076
+ Returns:
3077
+ Union[Dict[str, Any], None]: A dictionary mapping metric names to values; each value can be an integer or a string.
3078
+ """
3079
+
3080
+ args = parse_arguments(argv)
3081
+ setup_logger(args.verbose)
3082
+
3083
+ if args.model_type in ["t5", "mt5"]:
3084
+ if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx):
3085
+ raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}")
3086
+ if args.decoder_onnx and not os.path.exists(args.decoder_onnx):
3087
+ raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}")
3088
+ if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or (
3089
+ args.decoder_onnx and not args.encoder_decoder_init_onnx
3090
+ ):
3091
+ raise ValueError("--decoder_onnx shall use together with --encoder_decoder_init_onnx")
3092
+
3093
+ is_greedy = args.num_beams == 1 and args.num_return_sequences == 1
3094
+
3095
+ if args.model_type == "gpt2" and is_greedy:
3096
+ if args.top_p > 0.0 and args.top_p < 1.0:
3097
+ convert_generation_model(args, GenerationType.SAMPLING)
3098
+ logger.info(
3099
+ "The test for gpt2_sampling onnx model is limited to non-custom model with small top_p(e.g <=0.01) value. The result should be the same as gpt2 greedy search."
3100
+ )
3101
+ if args.top_p > 0.01 or args.custom or args.seed:
3102
+ return
3103
+ else:
3104
+ convert_generation_model(args, GenerationType.GREEDYSEARCH)
3105
+ else:
3106
+ convert_generation_model(args)
3107
+
3108
+ logger.info("start testing model...")
3109
+ if args.model_type in ["t5", "mt5"]:
3110
+ result = test_t5_model(args, sentences=sentences)
3111
+ else:
3112
+ result = test_gpt_model(args, sentences=sentences, is_greedy=is_greedy)
3113
+
3114
+ if result:
3115
+ if args.use_external_data_format:
3116
+ logger.info(f"Output files: {args.output}, {args.output}.data")
3117
+ else:
3118
+ logger.info(f"Output file: {args.output}")
3119
+
3120
+ return result
3121
+
3122
+
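+ # A minimal usage sketch (illustrative; the flag spellings are assumptions based on the
+ # attribute names used above and are not verified against parse_arguments):
+ #
+ #     metrics = main(
+ #         argv=["-m", "gpt2", "--output", "gpt2_beam_search.onnx", "--num_beams", "4"],
+ #         sentences=["The product is released"],
+ #     )
+ #
+ # converts the model and runs the ONNX Runtime parity/latency test, returning the metrics
+ # dictionary described above (or None when testing is skipped).
+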
3123
+ if __name__ == "__main__":
3124
+ main()