onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import tensorrt as trt
+
+ TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+
+
+ def init_trt_plugins():
+     # Register TensorRT plugins
+     trt.init_libnvinfer_plugins(TRT_LOGGER, "")
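Note: this 12-line hunk appears to correspond to onnxruntime/transformers/models/stable_diffusion/trt_utilities.py in the listing above; it only registers the libnvinfer plugins against a shared logger. As a rough illustration (not part of the package; the import path and the builder steps are assumptions), a caller in a TensorRT engine-build path might use it like this:

import tensorrt as trt
from trt_utilities import init_trt_plugins, TRT_LOGGER  # assumed import path for the hunk above

# Register libnvinfer plugins once, before building or deserializing any engine.
init_trt_plugins()
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network()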
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os.path
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)
@@ -0,0 +1,278 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import copy
+ import logging
+ import os
+
+ import torch
+ from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
+ from t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS, T5Helper
+
+ logger = logging.getLogger("")
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser()
+
+     pretrained_models = PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS
+     parser.add_argument(
+         "-m",
+         "--model_name_or_path",
+         required=False,
+         default=PRETRAINED_T5_MODELS[0],
+         type=str,
+         help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
+     )
+
+     parser.add_argument(
+         "--model_type",
+         required=False,
+         type=str,
+         default="t5",
+         choices=["t5", "mt5"],
+         help="Model type: either t5 (default) or mt5",
+     )
+
+     parser.add_argument(
+         "--cache_dir",
+         required=False,
+         type=str,
+         default=os.path.join(".", "cache_models"),
+         help="Directory to cache pre-trained models",
+     )
+
+     parser.add_argument(
+         "--output",
+         required=False,
+         type=str,
+         default=os.path.join(".", "onnx_models"),
+         help="Output directory",
+     )
+
+     parser.add_argument(
+         "-o",
+         "--optimize_onnx",
+         required=False,
+         action="store_true",
+         help="Use optimizer.py to optimize onnx model",
+     )
+     parser.set_defaults(optimize_onnx=False)
+
+     parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
+     parser.set_defaults(use_gpu=False)
+
+     parser.add_argument(
+         "-p",
+         "--precision",
+         required=False,
+         type=Precision,
+         default=Precision.FLOAT32,
+         choices=[Precision.FLOAT32, Precision.FLOAT16],
+         help="Precision of model to run. fp32 for full precision, fp16 for half precision",
+     )
+
+     parser.add_argument("--verbose", required=False, action="store_true")
+     parser.set_defaults(verbose=False)
+
+     parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
+     parser.set_defaults(use_external_data_format=False)
+
+     parser.add_argument(
+         "-s",
+         "--use_decoder_start_token",
+         required=False,
+         action="store_true",
+         help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.",
+     )
+     parser.set_defaults(use_decoder_start_token=False)
+
+     parser.add_argument(
+         "-w",
+         "--overwrite",
+         required=False,
+         action="store_true",
+         help="overwrite existing ONNX model",
+     )
+     parser.set_defaults(overwrite=False)
+
+     parser.add_argument(
+         "--disable_auto_mixed_precision",
+         required=False,
+         action="store_true",
+         help="use pure fp16 instead of mixed precision",
+     )
+     parser.set_defaults(disable_auto_mixed_precision=False)
+
+     parser.add_argument(
+         "--separate_encoder_and_decoder_init",
+         required=False,
+         action="store_true",
+         help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.",
+     )
+     parser.set_defaults(separate_encoder_and_decoder_init=False)
+
+     parser.add_argument(
+         "--use_int64_inputs",
+         required=False,
+         action="store_true",
+         help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
+     )
+     parser.set_defaults(use_int64_inputs=False)
+
+     parser.add_argument(
+         "--state_dict_path",
+         type=str,
+         default="",
+         help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
+     )
+
+     args = parser.parse_args()
+
+     return args
+
+
+ def export_onnx_models(
+     model_name_or_path,
+     cache_dir,
+     output_dir,
+     use_gpu,
+     use_external_data_format,
+     optimize_onnx,
+     precision,
+     verbose,
+     use_decoder_start_token: bool = False,
+     merge_encoder_and_decoder_init: bool = True,
+     overwrite: bool = False,
+     disable_auto_mixed_precision: bool = False,
+     use_int32_inputs: bool = True,
+     model_type: str = "t5",
+     state_dict_path: str = "",
+ ):
+     device = torch.device("cuda:0" if use_gpu else "cpu")
+
+     models = T5Helper.load_model(
+         model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path
+     )
+     config = models["decoder"].config
+
+     if (not use_external_data_format) and (config.num_layers > 24):
+         logger.info("Try use_external_data_format when model size > 2GB")
+
+     output_paths = []
+     for name, model in models.items():
+         model.to(device)
+         filename_suffix = "_" + name
+
+         onnx_path = T5Helper.get_onnx_path(
+             output_dir,
+             model_name_or_path,
+             suffix=filename_suffix,
+             new_folder=False,
+         )
+
+         if overwrite or not os.path.exists(onnx_path):
+             logger.info(f"Exporting ONNX model to {onnx_path}")
+             # We have to clone model before exporting onnx, otherwise verify_onnx will report large difference.
+             cloned_model = copy.deepcopy(model).to(device)
+             T5Helper.export_onnx(
+                 cloned_model,
+                 device,
+                 onnx_path,
+                 verbose,
+                 use_external_data_format,
+                 use_decoder_input_ids=not use_decoder_start_token,
+                 use_int32_inputs=use_int32_inputs,
+             )
+         else:
+             logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
+
+         # Optimize ONNX graph. Note that we have not implemented graph optimization for T5 yet.
+         if optimize_onnx or precision != Precision.FLOAT32:
+             output_path = T5Helper.get_onnx_path(
+                 output_dir,
+                 model_name_or_path,
+                 suffix=filename_suffix + "_" + str(precision),
+                 new_folder=False,
+             )
+
+             if overwrite or not os.path.exists(output_path):
+                 logger.info(f"Optimizing model to {output_path}")
+                 T5Helper.optimize_onnx(
+                     onnx_path,
+                     output_path,
+                     precision == Precision.FLOAT16,
+                     config.num_heads,
+                     config.hidden_size,
+                     use_external_data_format,
+                     auto_mixed_precision=not disable_auto_mixed_precision,
+                     use_gpu=use_gpu,
+                 )
+             else:
+                 logger.info(f"Skip optimizing: existed ONNX model {onnx_path}")
+         else:
+             output_path = onnx_path
+
+         ort_session = create_onnxruntime_session(
+             output_path,
+             use_gpu=use_gpu,
+             provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"],
+         )
+
+         with torch.no_grad():
+             max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
+         logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
+         if max_diff > 1e-4:
+             logger.warning("PyTorch and OnnxRuntime results are NOT close")
+
+         output_paths.append(output_path)
+
+     return output_paths
+
+
+ def main():
+     args = parse_arguments()
+
+     setup_logger(args.verbose)
+
+     logger.info(f"Arguments:{args}")
+
+     cache_dir = args.cache_dir
+     output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
+     prepare_environment(cache_dir, output_dir, args.use_gpu)
+
+     if args.precision != Precision.FLOAT32:
+         assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
+
+     if args.precision == Precision.FLOAT16:
+         assert args.use_gpu, "fp16 requires --use_gpu"
+
+     if args.optimize_onnx:
+         logger.warning("Graph optimization for T5 is not implemented yet.")
+
+     output_paths = export_onnx_models(
+         args.model_name_or_path,
+         cache_dir,
+         output_dir,
+         args.use_gpu,
+         args.use_external_data_format,
+         args.optimize_onnx,
+         args.precision,
+         args.verbose,
+         args.use_decoder_start_token,
+         not args.separate_encoder_and_decoder_init,
+         args.overwrite,
+         args.disable_auto_mixed_precision,
+         not args.use_int64_inputs,
+         args.model_type,
+     )
+
+     logger.info(f"Done! Outputs: {output_paths}")
+
+
+ if __name__ == "__main__":
+     main()
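For orientation, this 278-line hunk matches onnxruntime/transformers/models/t5/convert_to_onnx.py in the file listing. The sketch below is not part of the package; it only illustrates one way the export_onnx_models function defined above could be driven programmatically instead of via the command line. The module name convert_to_onnx, the "t5-small" checkpoint, and the local directories are placeholders/assumptions.

import os

from benchmark_helper import Precision
from convert_to_onnx import export_onnx_models  # assumed module name for the hunk above

cache_dir = os.path.join(".", "cache_models")
output_dir = os.path.join(".", "onnx_models")
os.makedirs(output_dir, exist_ok=True)

# Export the T5 encoder/decoder ONNX models for a small checkpoint on CPU, FP32, no graph optimization.
paths = export_onnx_models(
    "t5-small",  # placeholder checkpoint; any name in PRETRAINED_T5_MODELS would do
    cache_dir,
    output_dir,
    use_gpu=False,
    use_external_data_format=False,
    optimize_onnx=False,
    precision=Precision.FLOAT32,
    verbose=False,
)
print(paths)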
@@ -0,0 +1,150 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+ from typing import List, Tuple
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ class PastKeyValuesHelper:
+     """Helper functions to process past key values for encoder-decoder model"""
+
+     @staticmethod
+     def get_past_names(num_layers, present: bool = False):
+         past_self_names = []
+         past_cross_names = []
+         for i in range(num_layers):
+             past_self_names.extend(
+                 [f"present_key_self_{i}", f"present_value_self_{i}"]
+                 if present
+                 else [f"past_key_self_{i}", f"past_value_self_{i}"]
+             )
+             past_cross_names.extend(
+                 [f"present_key_cross_{i}", f"present_value_cross_{i}"]
+                 if present
+                 else [f"past_key_cross_{i}", f"past_value_cross_{i}"]
+             )
+         return past_self_names + past_cross_names
+
+     @staticmethod
+     def group_by_self_or_cross(present_key_values):
+         """Split present state from grouped by layer to grouped by self/cross attention.
+         Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+         After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+
+         """
+         present_self = []
+         present_cross = []
+         for _i, present_layer_i in enumerate(present_key_values):
+             assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+             (
+                 present_key_self,
+                 present_value_self,
+                 present_key_cross,
+                 present_value_cross,
+             ) = present_layer_i
+             present_self.extend([present_key_self, present_value_self])
+             present_cross.extend([present_key_cross, present_value_cross])
+         return present_self, present_cross
+
+     @staticmethod
+     def group_by_layer(past, num_layers):
+         """Reorder past state from grouped by self/cross attention to grouped by layer.
+         Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+         After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
+         """
+         assert len(past) == 4 * num_layers
+         return tuple(
+             [
+                 past[2 * i],
+                 past[2 * i + 1],
+                 past[2 * num_layers + 2 * i],
+                 past[2 * num_layers + 2 * i + 1],
+             ]
+             for i in range(num_layers)
+         )
+
+     @staticmethod
+     def back_group_by_layer(past_key_values: Tuple[Tuple[torch.Tensor]]):
+         """Categorize present_key_values from self and cross attention to layer by layer.
+
+         Reorder past state from grouped by self/cross attention to grouped by layer.
+         Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
+                 past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+         After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
+
+         Args:
+             present_key_values: From past_key_values of a model (group by self and cross attention)
+
+         Returns:
+             past_tuples: present key and values grouped by layer.
+         """
+         past_tuples = ()
+         half_idx = len(past_key_values) // 2
+         for i in range(len(past_key_values) // 4):
+             idx = 2 * i
+             past_tuples += (
+                 (
+                     past_key_values[idx],
+                     past_key_values[idx + 1],
+                     past_key_values[half_idx + idx],
+                     past_key_values[half_idx + idx + 1],
+                 ),
+             )
+         return past_tuples
+
+     @staticmethod
+     def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: bool = False):
+         """Categorize present_key_values into self and cross attention.
+
+         Split present state from grouped by layer to grouped by self/cross attention.
+         Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+                 (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+         After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
+                (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+
+         Args:
+             present_key_values: From past_key_values of a model (group by layer)
+             concat: If concat self attention with cross attention key/value to return
+
+         Returns:
+             present_self (Tuple[torch.Tensor]): present key and values from self attention
+             present_cross (Tuple[torch.Tensor]): present key and values from cross attention
+         """
+         present_self: List[torch.Tensor] = []
+         present_cross: List[torch.Tensor] = []
+         for _, present_layer_i in enumerate(present_key_values):
+             assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+             present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
+             present_self.extend([present_key_self, present_value_self])
+             present_cross.extend([present_key_cross, present_value_cross])
+         if concat:
+             return present_self + present_cross
+         else:
+             return present_self, present_cross
+
+     @staticmethod
+     def get_input_names(past_key_values: Tuple[Tuple[torch.Tensor]], encoder=True):
+         """Process input names of model wrapper.
+
+         Args:
+             past_key_values: Consider `self` and `cross` past_key_values
+
+         Returns:
+             names (List[string]): input names
+         """
+         names = []
+         num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
+         prefix = "past_" if not encoder else "present_"
+         for i in range(num_layers):
+             names.extend([prefix + s for s in [f"key_self_{i}", f"value_self_{i}"]])
+         for i in range(num_layers):
+             names.extend([prefix + s for s in [f"key_cross_{i}", f"value_cross_{i}"]])
+         return names
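This final hunk corresponds to onnxruntime/transformers/models/t5/past_helper.py. As a minimal, self-contained sketch (not from the package; the import path, layer count, and dummy tensor shapes are assumptions), the regrouping helpers above round-trip like this:

import torch

from past_helper import PastKeyValuesHelper  # assumed import path for the hunk above

num_layers = 2


def dummy():
    # Placeholder (batch, num_heads, sequence_length, head_size) tensor.
    return torch.zeros(1, 8, 4, 64)


# past_key_values as a decoder returns them: grouped by layer,
# each layer holding (key_self, value_self, key_cross, value_cross).
by_layer = tuple((dummy(), dummy(), dummy(), dummy()) for _ in range(num_layers))

# Flatten into self-attention tensors followed by cross-attention tensors.
flat_self, flat_cross = PastKeyValuesHelper.group_by_self_and_cross(by_layer)
names = PastKeyValuesHelper.get_past_names(num_layers, present=True)
assert len(flat_self) + len(flat_cross) == len(names)  # 4 * num_layers tensors and names

# Regroup the flat self/cross ordering back into per-layer 4-tuples.
regrouped = PastKeyValuesHelper.back_group_by_layer(tuple(flat_self + flat_cross))
assert len(regrouped) == num_layers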