onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,438 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import List, Optional, Union
12
+
13
+ import numpy
14
+ import onnx
15
+ import torch
16
+ from io_binding_helper import TypeHelper
17
+ from onnx_model import OnnxModel
18
+ from past_helper import PastKeyValuesHelper
19
+ from t5_encoder import T5EncoderInputs
20
+ from torch_onnx_export_helper import torch_onnx_export
21
+ from transformers import MT5Config, T5Config
22
+
23
+ from onnxruntime import InferenceSession
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class T5DecoderInit(torch.nn.Module):
    """T5 decoder plus LM head used for the very first decoding step.

    Produces the logits for the start token together with the initial
    self- and cross-attention past key/value states. It is invoked exactly
    once, before the with-past decoder (``T5Decoder``) takes over.
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
        lm_head: torch.nn.Module,
        config: Union[T5Config, MT5Config],
        decoder_start_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config
        # Fall back to the config's start token when none is given explicitly.
        if decoder_start_token_id is None:
            decoder_start_token_id = self.config.decoder_start_token_id
        self.decoder_start_token_id = decoder_start_token_id
        # Older configs may lack the flag; default True matches prior behavior.
        self.tie_word_embeddings = getattr(self.config, "tie_word_embeddings", True)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
    ):
        if decoder_input_ids is None:
            # Seed decoding with one start token per batch row.
            rows = encoder_attention_mask.shape[0]
            ones = torch.ones(
                (rows, 1),
                dtype=torch.long,
                device=encoder_attention_mask.device,
            )
            decoder_input_ids = ones * self.decoder_start_token_id

        outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        hidden = outputs.last_hidden_state
        if self.tie_word_embeddings:
            # Rescale hidden states before the tied LM head projection.
            hidden = hidden * (self.config.d_model**-0.5)

        logits = self.lm_head(hidden)
        past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(outputs.past_key_values)
        return logits, past_self, past_cross
85
+
86
+
87
class T5Decoder(torch.nn.Module):
    """A T5 decoder with LM head and past key values"""

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config
        # Older configs may lack the flag; default True matches prior behavior.
        self.tie_word_embeddings = getattr(self.config, "tie_word_embeddings", True)

    def forward(self, decoder_input_ids, encoder_attention_mask, *past):
        layers = self.config.num_decoder_layers
        past_key_values = PastKeyValuesHelper.group_by_layer(past, layers)

        # Hack: when cross-attention past states are supplied, only the third
        # dimension of encoder_hidden_states is actually used, so a cheap
        # stand-in derived from the attention mask is sufficient.
        placeholder_hidden_states = encoder_attention_mask.unsqueeze(2)

        outputs = self.decoder(
            input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            encoder_hidden_states=placeholder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        hidden = outputs.last_hidden_state
        if self.tie_word_embeddings:
            # Rescale hidden states before the tied LM head projection.
            hidden = hidden * (self.config.d_model**-0.5)

        logits = self.lm_head(hidden)
        present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(outputs.past_key_values)

        # present_cross is dropped: it is identical to the past_cross inputs.
        return logits, present_self
125
+
126
+
127
class T5DecoderInputs:
    """Bundles the tensors consumed by the exported T5 decoder."""

    def __init__(
        self,
        decoder_input_ids,
        encoder_attention_mask,
        past_key_values=None,
    ):
        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
        self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
        self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values

    @staticmethod
    def create_dummy(
        config: Union[T5Config, MT5Config],
        batch_size: int,
        encode_sequence_length: int,
        past_decode_sequence_length: int,
        device: torch.device,
        float16: bool = False,
        use_int32_inputs: bool = False,
    ):  # -> T5DecoderInputs:
        """Create dummy inputs for T5Decoder.

        Args:
            config: model configuration (T5Config or MT5Config)
            batch_size (int): batch size
            encode_sequence_length (int): sequence length of input_ids for encoder
            past_decode_sequence_length (int): past sequence length of input_ids for decoder
            device (torch.device): device of output tensors
            float16 (bool): whether the model uses float32 or float16 in input
            use_int32_inputs(bool): whether use int32 instead of int64 for some inputs

        Returns:
            T5DecoderInputs: dummy inputs for decoder
        """
        heads: int = config.num_heads
        layers: int = config.num_decoder_layers
        vocab_size: int = config.vocab_size

        # Use d_kv directly: head size is NOT always hidden_size / num_heads
        # (e.g. mt5-small has d_model=512 with num_heads=6).
        head_size: int = config.d_kv

        id_dtype = torch.int32 if use_int32_inputs else torch.int64
        # sequence length is fixed to 1 for per-step decoding
        decoder_input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, 1),
            dtype=id_dtype,
            device=device,
        )

        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size,
            encode_sequence_length,
            vocab_size,
            device,
            use_int32_inputs=use_int32_inputs,
        )

        past = None
        if past_decode_sequence_length > 0:
            value_dtype = torch.float16 if float16 else torch.float32
            self_shape = [
                batch_size,
                heads,
                past_decode_sequence_length,
                head_size,
            ]
            cross_shape = [
                batch_size,
                heads,
                encode_sequence_length,
                head_size,
            ]
            # 2 * layers self-attention tensors first, then 2 * layers cross.
            past = [torch.rand(self_shape, dtype=value_dtype, device=device) for _ in range(2 * layers)]
            past += [torch.rand(cross_shape, dtype=value_dtype, device=device) for _ in range(2 * layers)]

        return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)

    def to_list(self) -> List:
        """Return inputs as a flat ordered list (ids, mask, then past)."""
        tensors = [self.decoder_input_ids, self.encoder_attention_mask]
        if self.past_key_values:
            tensors.extend(self.past_key_values)
        return tensors

    def to_fp32(self):
        """Return a copy of these inputs with past states cast to float32."""
        if self.past_key_values:
            past = [tensor.to(dtype=torch.float32) for tensor in self.past_key_values]
        else:
            past = None
        return T5DecoderInputs(
            self.decoder_input_ids.clone(),
            self.encoder_attention_mask.clone(),
            past,
        )
230
+
231
+
232
+ class T5DecoderHelper:
233
    @staticmethod
    def export_onnx(
        decoder: Union[T5Decoder, T5DecoderInit],
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export decoder to ONNX

        Args:
            decoder (Union[T5Decoder, T5DecoderNoPastState]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs
        """
        assert isinstance(decoder, (T5Decoder, T5DecoderInit))

        # T5Decoder consumes past states, so trace it with a nonzero past
        # length; T5DecoderInit is the first step and takes no past.
        inputs = T5DecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3,
            past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        input_list = inputs.to_list()

        num_decoder_layers = decoder.config.num_decoder_layers

        past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
        # First 2 * num_decoder_layers names cover the self-attention key/value outputs.
        present_self_names = present_names[: 2 * num_decoder_layers]

        input_past_names = past_names if isinstance(decoder, T5Decoder) else []
        # T5Decoder only emits new self-attention states (its cross states are
        # unchanged and omitted); T5DecoderInit emits both self and cross.
        output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
        output_names = ["logits", *output_present_names]

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    encoder_attention_mask: (batch_size, encode_sequence_length)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["input_ids"]
        input_names.append("encoder_attention_mask")
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
            "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
            "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
            "logits": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
        }

        # Past inputs: axis 2 is the growing decode length for self-attention
        # states and the (fixed per run) encoder length for cross states.
        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
            else:  # self attention past state
                if isinstance(decoder, T5Decoder):
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        # When external data format is requested, export to a temp directory
        # first, then re-save so all tensors land in a single external file
        # next to onnx_model_path.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                decoder,
                args=tuple(input_list),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )
350
+
351
+ @staticmethod
352
+ def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
353
+ """Run inference of ONNX model."""
354
+ logger.debug("start onnxruntime_inference")
355
+
356
+ ort_inputs = {
357
+ "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
358
+ "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
359
+ }
360
+
361
+ if inputs.past_key_values:
362
+ assert len(inputs.past_key_values) % 4 == 0
363
+ num_layers = int(len(inputs.past_key_values) / 4)
364
+ past_names = PastKeyValuesHelper.get_past_names(num_layers)
365
+ for i, past_tensor in enumerate(inputs.past_key_values):
366
+ ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())
367
+
368
+ ort_outputs = ort_session.run(None, ort_inputs)
369
+ return ort_outputs
370
+
371
    @staticmethod
    def verify_onnx(
        model: Union[T5Decoder, T5DecoderInit],
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
        max_cases: int = 4,
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.

        Runs up to ``max_cases`` dummy-input configurations through both the
        PyTorch model and the ONNX Runtime session, and compares logits and
        present key/value states element-wise.

        NOTE(review): the value returned is ``max_diff_all`` of the LAST test
        case only; ``test_cases_max_diff`` collects all cases but is not
        aggregated — confirm whether callers expect the overall maximum.
        """
        # Infer fp16 vs fp32 from the type of a known past-state input of the session.
        float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"

        # Each tuple: (batch_size, encode_sequence_length, past_decode_sequence_length).
        test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
        test_cases_max_diff = []
        for (
            batch_size,
            encode_sequence_length,
            past_decode_sequence_length,
        ) in test_cases[:max_cases]:
            # A T5DecoderInit runs without past state, so force past length to 0.
            if isinstance(model, T5DecoderInit):
                past_decode_sequence_length = 0  # noqa: PLW2901

            inputs = T5DecoderInputs.create_dummy(
                model.config,
                batch_size,
                encode_sequence_length,
                past_decode_sequence_length,
                device=device,
                float16=float16,
                use_int32_inputs=use_int32_inputs,
            )

            # We use fp32 PyTorch model as baseline even when ONNX model is fp16
            input_list = inputs.to_fp32().to_list()

            # Run inference of PyTorch model
            with torch.no_grad():
                torch_outputs = model(*input_list)

            ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)
            num_decoder_layers = model.config.num_decoder_layers

            # ort_outputs[0] is logits; torch_outputs[0] is the corresponding tensor.
            max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
            max_diff_all = max_diff
            logger.debug(f"logits max_diff={max_diff}")

            # Self-attention present states: ort_outputs[1 : 1 + 2*num_decoder_layers]
            # (one key and one value tensor per decoder layer).
            for i in range(2 * num_decoder_layers):
                max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
                logger.debug(f"self attention past state {i} max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

            # Only T5DecoderInit also emits cross-attention present states,
            # located after the self-attention block in the output list.
            if isinstance(model, T5DecoderInit):
                for i in range(2 * num_decoder_layers):
                    max_diff = numpy.amax(
                        numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * num_decoder_layers + i])
                    )
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

            test_cases_max_diff.append(max_diff_all)
            logger.info(
                "batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
                batch_size,
                encode_sequence_length,
                past_decode_sequence_length,
                max_diff_all,
            )

        return max_diff_all
@@ -0,0 +1,171 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+ import random
10
+ import tempfile
11
+ from pathlib import Path
12
+ from typing import List, Union
13
+
14
+ import numpy
15
+ import onnx
16
+ import torch
17
+ from onnx_model import OnnxModel
18
+ from torch_onnx_export_helper import torch_onnx_export
19
+ from transformers import MT5Config, T5Config
20
+
21
+ from onnxruntime import InferenceSession
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class T5Encoder(torch.nn.Module):
    """Wrapper around the T5 encoder that exposes only the last hidden state."""

    def __init__(self, encoder, config: Union[T5Config, MT5Config]):
        super().__init__()
        self.encoder = encoder
        self.config = config

    def forward(self, input_ids, attention_mask):
        # The wrapped encoder returns a sequence of outputs; keep only the
        # first element (the last hidden state).
        encoder_outputs = self.encoder(input_ids, attention_mask)
        return encoder_outputs[0]
36
+
37
+
38
class T5EncoderInputs:
    """Container for the two tensors consumed by the T5 encoder."""

    def __init__(self, input_ids, attention_mask):
        # Both tensors have shape (batch_size, sequence_length).
        self.input_ids: torch.LongTensor = input_ids
        self.attention_mask: torch.LongTensor = attention_mask

    @staticmethod
    def create_dummy(
        batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
    ):  # -> T5EncoderInputs
        """Create dummy inputs for T5 encoder.

        Args:
            batch_size (int): batch size
            sequence_length (int): sequence length
            vocab_size (int): vocabulary size
            device (torch.device): device of output tensors
            use_int32_inputs (bool): use int32 instead of int64 for the tensors

        Returns:
            T5EncoderInputs: dummy inputs for encoder
        """
        token_dtype = torch.int32 if use_int32_inputs else torch.int64

        input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=token_dtype,
            device=device,
        )

        # Start from an all-ones mask, then zero a random-length left prefix of
        # each row to simulate padding (skipped for length-1 sequences).
        attention_mask = torch.ones([batch_size, sequence_length], dtype=token_dtype, device=device)
        if sequence_length >= 2:
            for row in range(batch_size):
                padding_position = random.randint(0, sequence_length - 1)
                attention_mask[row, :padding_position] = 0
        return T5EncoderInputs(input_ids, attention_mask)

    def to_list(self) -> List:
        """Return the non-None input tensors in positional-argument order."""
        return [tensor for tensor in (self.input_ids, self.attention_mask) if tensor is not None]
78
+
79
+
80
class T5EncoderHelper:
    """Export, inference, and parity-check helpers for the T5 encoder."""

    @staticmethod
    def export_onnx(
        encoder: T5Encoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export encoder to ONNX

        Args:
            encoder (T5Encoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 dummy inputs. Defaults to False.
        """
        config = encoder.config
        dummy_inputs = T5EncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=4,
            vocab_size=config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            # With external data format, export to a temp file first and then
            # re-save so that all tensors land in a single external data file.
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            export_target = temp_onnx_model_path if use_external_data_format else onnx_model_path
            # batch and sequence dimensions are dynamic for inputs and output.
            batch_seq_axes = {0: "batch_size", 1: "sequence_length"}
            torch_onnx_export(
                encoder,
                args=tuple(dummy_inputs.to_list()),
                f=export_target,
                export_params=True,
                input_names=["input_ids", "attention_mask"],
                output_names=["hidden_states"],
                dynamic_axes={
                    "input_ids": dict(batch_seq_axes),
                    "attention_mask": dict(batch_seq_axes),
                    "hidden_states": dict(batch_seq_axes),
                },
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: T5EncoderInputs):
        """Run inference of ONNX model."""
        feeds = {
            "input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
            "attention_mask": numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()),
        }
        return ort_session.run(None, feeds)

    @staticmethod
    def verify_onnx(
        model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        inputs = T5EncoderInputs.create_dummy(
            batch_size=4,
            sequence_length=11,
            vocab_size=model.config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        torch_outputs = model(*inputs.to_list())

        ort_outputs = T5EncoderHelper.onnxruntime_inference(ort_session, inputs)

        # Element-wise maximum absolute difference between the two backends.
        max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))

        logger.info(f"max_diff={max_diff}")

        return max_diff