onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
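Quick sanity check (a sketch, not part of the package contents): after installing this wheel, the DirectML build should expose DmlExecutionProvider in addition to CPUExecutionProvider. The "model.onnx" path below is a placeholder.

    import onnxruntime as ort

    # DmlExecutionProvider is the DirectML execution provider shipped in this wheel.
    print(ort.get_available_providers())

    # Prefer DirectML, falling back to CPU for any unsupported node.
    session = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider", "CPUExecutionProvider"])
    print([inp.name for inp in session.get_inputs()])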
onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py
@@ -0,0 +1,299 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import List, Optional, Union
+
+ import numpy
+ import onnx
+ import torch
+ from onnx_model import OnnxModel
+ from past_helper import PastKeyValuesHelper
+ from t5_decoder import T5DecoderInit
+ from t5_encoder import T5Encoder, T5EncoderInputs
+ from torch_onnx_export_helper import torch_onnx_export
+ from transformers import MT5Config, T5Config
+
+ from onnxruntime import InferenceSession
+
+ logger = logging.getLogger(__name__)
+
+
+ class T5EncoderDecoderInit(torch.nn.Module):
+     """A combination of T5Encoder and T5DecoderInit."""
+
+     def __init__(
+         self,
+         encoder: torch.nn.Module,
+         decoder: torch.nn.Module,
+         lm_head: torch.nn.Module,
+         config: Union[T5Config, MT5Config],
+         decoder_start_token_id: Optional[int] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.t5_encoder = T5Encoder(encoder, config)
+         self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
+
+     def forward(
+         self,
+         encoder_input_ids: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         decoder_input_ids: torch.Tensor = None,
+     ):
+         encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)
+         lm_logits, past_self, past_cross = self.t5_decoder_init(
+             decoder_input_ids, encoder_attention_mask, encoder_hidden_states
+         )
+         return lm_logits, encoder_hidden_states, past_self, past_cross
+
+
+ class T5EncoderDecoderInitInputs:
+     def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
+         self.encoder_input_ids: torch.LongTensor = encoder_input_ids
+         self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
+         self.decoder_input_ids: torch.LongTensor = decoder_input_ids
+
+     @staticmethod
+     def create_dummy(
+         config: Union[T5Config, MT5Config],
+         batch_size: int,
+         encode_sequence_length: int,
+         use_decoder_input_ids: int,
+         device: torch.device,
+         use_int32_inputs: bool = False,
+     ):  # -> T5EncoderDecoderInitInputs:
+         encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
+             batch_size,
+             encode_sequence_length,
+             config.vocab_size,
+             device,
+             use_int32_inputs=use_int32_inputs,
+         )
+         decoder_input_ids = None
+         if use_decoder_input_ids:
+             dtype = torch.int32 if use_int32_inputs else torch.int64
+             decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
+
+         return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
+
+     def to_list(self) -> List:
+         input_list = [self.encoder_input_ids, self.encoder_attention_mask]
+         if self.decoder_input_ids is not None:
+             input_list.append(self.decoder_input_ids)
+         return input_list
+
+
+ class T5EncoderDecoderInitHelper:
+     @staticmethod
+     def export_onnx(
+         model: T5EncoderDecoderInit,
+         device: torch.device,
+         onnx_model_path: str,
+         use_decoder_input_ids: bool = True,
+         verbose: bool = True,
+         use_external_data_format: bool = False,
+         use_int32_inputs: bool = False,
+     ):
+         """Export the merged encoder and decoder-init model to ONNX
+
+         Args:
+             model (T5EncoderDecoderInit): the model to export
+             device (torch.device): device to run the model
+             onnx_model_path (str): onnx path
+             verbose (bool, optional): print verbose information. Defaults to True.
+             use_external_data_format (bool, optional): use external data format or not. Defaults to False.
+         """
+         assert isinstance(model, T5EncoderDecoderInit)
+
+         inputs = T5EncoderDecoderInitInputs.create_dummy(
+             model.config,
+             batch_size=2,
+             encode_sequence_length=3,
+             use_decoder_input_ids=use_decoder_input_ids,
+             device=device,
+             use_int32_inputs=use_int32_inputs,
+         )
+         input_list = inputs.to_list()
+
+         present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
+
+         output_names = ["logits", "encoder_hidden_states", *present_names]
+
+         # Shape of input tensors (sequence_length==1):
+         #    input_ids: (batch_size, sequence_length)
+         #    encoder_attention_mask: (batch_size, encode_sequence_length)
+         #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
+         #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
+         #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
+
+         # Shape of output tensors:
+         #    logits: (batch_size, sequence_length, vocab_size)
+         #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
+         #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
+
+         input_names = ["encoder_input_ids", "encoder_attention_mask"]
+
+         # The ONNX exporter might mark a dimension like 'Transposepresent_value_self_1_dim_2' in shape inference.
+         # We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
+         sequence_length = "1"
+         num_heads = str(model.config.num_heads)
+         hidden_size = str(model.config.d_model)
+         head_size = str(model.config.d_kv)
+
+         dynamic_axes = {
+             "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
+             "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
+             "encoder_hidden_states": {
+                 0: "batch_size",
+                 1: "encode_sequence_length",
+                 2: hidden_size,
+             },
+             "logits": {
+                 0: "batch_size",
+                 1: sequence_length,
+             },
+         }
+
+         if use_decoder_input_ids:
+             input_names.append("decoder_input_ids")
+             dynamic_axes["decoder_input_ids"] = {
+                 0: "batch_size",
+                 1: sequence_length,
+             }
+
+         for name in present_names:
+             if "cross" in name:
+                 dynamic_axes[name] = {
+                     0: "batch_size",
+                     1: num_heads,
+                     2: "encode_sequence_length",
+                     3: head_size,
+                 }
+
+             else:  # self attention past state
+                 dynamic_axes[name] = {
+                     0: "batch_size",
+                     1: num_heads,
+                     2: sequence_length,
+                     3: head_size,
+                 }
+
+         with tempfile.TemporaryDirectory() as tmp_dir_name:
+             temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
+             Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+             torch_onnx_export(
+                 model,
+                 args=tuple(input_list),
+                 f=temp_onnx_model_path,
+                 export_params=True,
+                 input_names=input_names,
+                 output_names=output_names,
+                 dynamic_axes=dynamic_axes,
+                 opset_version=12,
+                 do_constant_folding=True,
+                 use_external_data_format=use_external_data_format,
+                 verbose=verbose,
+             )
+
+             # Workaround as mentioned earlier: change numeric dim_param to dim_value
+             model = onnx.load(temp_onnx_model_path)
+             for tensor in model.graph.output:
+                 for dim_proto in tensor.type.tensor_type.shape.dim:
+                     if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
+                         sequence_length,
+                         num_heads,
+                         hidden_size,
+                         head_size,
+                     ]:
+                         dim_value = int(dim_proto.dim_param)
+                         dim_proto.Clear()
+                         dim_proto.dim_value = dim_value
+
+             OnnxModel.save(
+                 model,
+                 onnx_model_path,
+                 save_as_external_data=use_external_data_format,
+                 all_tensors_to_one_file=True,
+             )
+
+     @staticmethod
+     def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
+         """Run inference of ONNX model."""
+         logger.debug("start onnxruntime_inference")
+
+         ort_inputs = {
+             "encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
+             "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
+         }
+         if inputs.decoder_input_ids is not None:
+             ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
+
+         ort_outputs = ort_session.run(None, ort_inputs)
+         return ort_outputs
+
+     @staticmethod
+     def verify_onnx(
+         model: T5EncoderDecoderInit,
+         ort_session: InferenceSession,
+         device: torch.device,
+         use_int32_inputs: bool,
+         max_cases: int = 4,
+     ):
+         """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
+         ort_inputs = ort_session.get_inputs()
+         use_decoder_input_ids = len(ort_inputs) == 3
+
+         test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
+         test_cases_max_diff = []
+         for batch_size, encode_sequence_length in test_cases[:max_cases]:
+             inputs = T5EncoderDecoderInitInputs.create_dummy(
+                 model.config,
+                 batch_size,
+                 encode_sequence_length,
+                 use_decoder_input_ids=use_decoder_input_ids,
+                 device=device,
+                 use_int32_inputs=use_int32_inputs,
+             )
+
+             ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
+
+             # Run inference of PyTorch model
+             input_list = inputs.to_list()
+             torch_outputs = model(*input_list)
+
+             num_decoder_layers = model.config.num_decoder_layers
+
+             assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
+             max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
+             logger.debug(f"logits max_diff={max_diff}")
+             max_diff_all = max_diff
+
+             assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
+             max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
+             logger.debug(f"encoder_hidden_states max_diff={max_diff}")
+             max_diff_all = max(max_diff_all, max_diff)
+
+             for i in range(2 * num_decoder_layers):
+                 max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
+                 logger.debug(f"self attention past state {i} max_diff={max_diff}")
+                 max_diff_all = max(max_diff_all, max_diff)  # fold self attention diffs into the overall max
+             for i in range(2 * num_decoder_layers):
+                 max_diff = numpy.amax(
+                     numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
+                 )
+                 logger.debug(f"cross attention past state {i} max_diff={max_diff}")
+                 max_diff_all = max(max_diff_all, max_diff)
+
+             test_cases_max_diff.append(max_diff_all)
+             logger.info(
+                 f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
+             )
+
+         return max(test_cases_max_diff)
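A minimal usage sketch for the helper above (not part of the diff). It assumes the script runs inside onnxruntime/transformers/models/t5, or after importing onnxruntime.transformers.models.t5 so its __init__.py (shown at the end of this diff) puts the flat-import modules on sys.path; the model name and output path are illustrative.

    import torch
    from transformers import T5ForConditionalGeneration

    from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
    from onnxruntime import InferenceSession

    device = torch.device("cpu")
    hf_model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

    # Wrap the Hugging Face submodules in the merged encoder/decoder-init module defined above.
    wrapper = T5EncoderDecoderInit(hf_model.encoder, hf_model.decoder, hf_model.lm_head, hf_model.config).to(device)

    onnx_path = "t5_small_encoder_decoder_init.onnx"
    T5EncoderDecoderInitHelper.export_onnx(wrapper, device, onnx_path, use_int32_inputs=True)

    # Compare ONNX Runtime output against PyTorch over the built-in test cases.
    session = InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    max_diff = T5EncoderDecoderInitHelper.verify_onnx(wrapper, session, device, use_int32_inputs=True)
    print(f"max diff vs PyTorch: {max_diff}")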
onnxruntime/transformers/models/t5/t5_helper.py
@@ -0,0 +1,272 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Dict, List, Union
+
+ import torch
+ from float16 import float_to_float16_max_diff
+ from onnx_model import OnnxModel
+ from optimizer import optimize_model
+ from t5_decoder import T5Decoder, T5DecoderHelper, T5DecoderInit
+ from t5_encoder import T5Encoder, T5EncoderHelper
+ from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
+ from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
+
+ from onnxruntime import InferenceSession
+
+ logger = logging.getLogger(__name__)
+
+ PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
+ PRETRAINED_MT5_MODELS = ["google/mt5-small", "google/mt5-base", "google/mt5-large", "google/mt5-xl", "google/mt5-xxl"]
+
+
+ class T5Helper:
+     @staticmethod
+     def get_onnx_path(
+         output_dir: str,
+         model_name_or_path: str,
+         suffix: str = "",
+         new_folder: bool = False,
+     ) -> str:
+         """Build onnx path
+
+         Args:
+             output_dir (str): output directory
+             model_name_or_path (str): pretrained model name, or path to the model checkpoint
+             suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to file name. Defaults to "".
+             new_folder (bool, optional): create a new directory for the model. Defaults to False.
+
+         Returns:
+             str: path of onnx model
+         """
+         model_name = model_name_or_path
+         if os.path.isdir(model_name_or_path):
+             model_name = Path(model_name_or_path).parts[-1]
+         else:
+             model_name = model_name_or_path.split("/")[-1]  # drop the organization prefix of a model hub name
+
+         model_name += suffix
+
+         directory = os.path.join(output_dir, model_name) if new_folder else output_dir
+         return os.path.join(directory, model_name + ".onnx")
+
+     @staticmethod
+     def load_model(
+         model_name_or_path: str,
+         cache_dir: str,
+         device: torch.device,
+         merge_encoder_and_decoder_init: bool = True,
+         model_type: str = "t5",
+         state_dict_path: str = "",
+     ) -> Dict[str, torch.nn.Module]:
+         """Load model given a pretrained name or path, then build models for ONNX conversion.
+
+         Args:
+             model_name_or_path (str): pretrained model name or path
+             cache_dir (str): cache directory
+             device (torch.device): device to run the model
+             merge_encoder_and_decoder_init (bool, optional): whether to merge encoder and decoder initialization into one ONNX model. Defaults to True.
+             model_type (str, optional): model type, "t5" or "mt5". Defaults to "t5".
+         Returns:
+             Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
+         """
+         if model_type == "t5":
+             model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+         elif model_type == "mt5":
+             model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+         else:
+             raise ValueError("only support model_type=t5 or mt5")
+
+         if state_dict_path:
+             model.load_state_dict(torch.load(state_dict_path))
+
+         decoder = T5Decoder(model.decoder, model.lm_head, model.config)
+         decoder.eval().to(device)
+
+         if merge_encoder_and_decoder_init:
+             encoder_decoder_init = T5EncoderDecoderInit(
+                 model.encoder,
+                 model.decoder,
+                 model.lm_head,
+                 model.config,
+                 decoder_start_token_id=None,
+             )
+             return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
+         else:
+             encoder = T5Encoder(model.encoder, model.config)
+             encoder.eval().to(device)
+             decoder_init = T5DecoderInit(model.decoder, model.lm_head, model.config)
+             decoder_init.eval().to(device)
+             return {
+                 "encoder": encoder,
+                 "decoder": decoder,
+                 "decoder_init": decoder_init,
+             }
+
+     @staticmethod
+     def export_onnx(
+         model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
+         device: torch.device,
+         onnx_model_path: str,
+         verbose: bool = True,
+         use_external_data_format: bool = False,
+         use_decoder_input_ids: bool = True,
+         use_int32_inputs: bool = False,
+     ):
+         if isinstance(model, T5Encoder):
+             T5EncoderHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs,
+             )
+         elif isinstance(model, T5EncoderDecoderInit):
+             T5EncoderDecoderInitHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 use_decoder_input_ids,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs,
+             )
+         else:
+             T5DecoderHelper.export_onnx(
+                 model,
+                 device,
+                 onnx_model_path,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs,
+             )
+
+     @staticmethod
+     def auto_mixed_precision(
+         onnx_model: OnnxModel,
+         op_block_list: List[str] = [  # noqa: B006
+             "SimplifiedLayerNormalization",
+             "SkipSimplifiedLayerNormalization",
+             "Relu",
+             "Add",
+         ],
+     ):
+         """Convert model to mixed precision.
+         It detects whether the original model has fp16 weights, and sets parameters for float16 conversion automatically.
+         Args:
+             onnx_model (OnnxModel): optimized ONNX model
+             op_block_list (List[str], optional): operators to keep in float32. Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
+         Returns:
+             parameters(dict): a dictionary of parameters used in float16 conversion
+         """
+         op_full_set = {node.op_type for node in onnx_model.nodes()}
+         fp32_op_set = set(op_block_list)
+         fp16_op_set = op_full_set.difference(fp32_op_set)
+         logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
+
+         # logits is the first output
+         logits_output_name = onnx_model.graph().output[0].name
+
+         # We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
+         is_weight_fp16_precision = False
+         output_name_to_node = onnx_model.output_name_to_node()
+         assert logits_output_name in output_name_to_node
+         node = output_name_to_node[logits_output_name]
+         last_matmul_node = None
+         if node.op_type == "MatMul":
+             last_matmul_node = node
+             logger.info(f"Found last MatMul node for logits: {node.name}")
+             initializer = None
+             for input in node.input:
+                 initializer = onnx_model.get_initializer(input)
+                 if initializer is not None:
+                     break
+
+             # When the max difference of values after converting float to float16 is lower than a threshold (1e-6),
+             # we can deduce that the weights are stored in float16 precision.
+             max_diff = float_to_float16_max_diff(initializer)
+             logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
+             is_weight_fp16_precision = max_diff < 1e-6
+         else:
+             logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
+
+         keep_io_types = []
+         node_block_list = []
+         if (not is_weight_fp16_precision) and (last_matmul_node is not None):
+             # When the original weights are float32, keeping logits and the last MatMul in float32 can give better precision.
+             keep_io_types = [logits_output_name]
+             node_block_list = [last_matmul_node.name]
+
+         parameters = {
+             "keep_io_types": keep_io_types,
+             "op_block_list": op_block_list,
+             "node_block_list": node_block_list,
+             "force_fp16_initializers": is_weight_fp16_precision,
+         }
+
+         logger.info(f"auto_mixed_precision parameters: {parameters}")
+         onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
+
+         return parameters
+
+     @staticmethod
+     def optimize_onnx(
+         onnx_model_path: str,
+         optimized_model_path: str,
+         is_float16: bool,
+         num_attention_heads: int,
+         hidden_size: int,
+         use_external_data_format: bool = False,
+         auto_mixed_precision: bool = True,
+         use_gpu: bool = False,
+     ):
+         """Optimize ONNX model with an option to convert it to use mixed precision."""
+
+         from fusion_options import FusionOptions
+
+         optimization_options = None
+         if is_float16:
+             optimization_options = FusionOptions("t5")
+             optimization_options.enable_skip_layer_norm = False
+
+         m = optimize_model(
+             onnx_model_path,
+             model_type="t5",
+             num_heads=num_attention_heads,
+             hidden_size=hidden_size,
+             opt_level=2 if not use_external_data_format else 0,
+             optimization_options=optimization_options,
+             use_gpu=False,
+             only_onnxruntime=not use_gpu,
+         )
+
+         if is_float16:
+             if auto_mixed_precision:
+                 T5Helper.auto_mixed_precision(m)
+             else:
+                 m.convert_model_float32_to_float16(cast_input_output=False)
+
+         m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
+
+     @staticmethod
+     def verify_onnx(
+         model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
+         ort_session: InferenceSession,
+         device: torch.device,
+         use_int32_inputs: bool,
+     ):
+         """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
+         if isinstance(model, T5Encoder):
+             return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
+
+         if isinstance(model, T5EncoderDecoderInit):
+             return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
+
+         return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
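An end-to-end sketch for T5Helper (not part of the diff; same sys.path assumption as the earlier sketch, with illustrative model name and paths): load the PyTorch modules, export each to ONNX, then write an fp16 copy using the auto mixed precision logic above.

    import os

    import torch
    from transformers import T5Config

    from t5_helper import T5Helper

    device = torch.device("cpu")
    config = T5Config.from_pretrained("t5-small", cache_dir="./cache")
    models = T5Helper.load_model("t5-small", cache_dir="./cache", device=device)

    os.makedirs("./onnx_models", exist_ok=True)
    for name, module in models.items():
        # Export the fp32 graph for this sub-model (encoder_decoder_init and decoder by default).
        onnx_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name)
        T5Helper.export_onnx(module, device, onnx_path, use_int32_inputs=True)

        # Write an optimized fp16 copy; auto_mixed_precision keeps sensitive ops in float32.
        fp16_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name + "_fp16")
        T5Helper.optimize_onnx(
            onnx_path,
            fp16_path,
            is_float16=True,
            num_attention_heads=config.num_heads,
            hidden_size=config.d_model,
        )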
onnxruntime/transformers/models/t5/__init__.py
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)
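This __init__.py is the sys.path shim that makes the flat imports in the files above (from t5_helper import T5Helper, from onnx_model import OnnxModel, and so on) resolve: on import it adds both the models/t5 directory and the transformers tools root to sys.path. One consequence, assuming the usual argparse entry point in these scripts, is that the conversion script shipped alongside these helpers can be run as a module:

    python -m onnxruntime.transformers.models.t5.convert_to_onnx --help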