onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,414 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import numpy as np
7
+ import torch
8
+ from transformers import AutoTokenizer
9
+
10
+ import onnxruntime as ort
11
+
12
# Lookup from ``repr(torch.dtype)`` strings (e.g. "torch.float16") to the numpy dtype
# handed to ONNX Runtime's IOBinding as ``element_type``.
pt_to_np = {f"torch.{name}": getattr(np, name) for name in ("int32", "int64", "float32", "float16")}
18
+
19
+
20
def cuda_memcpy(dst, src):
    """Copy all of ``src``'s bytes into ``dst`` with ``cudaMemcpy`` (device-to-device).

    Args:
        dst: destination torch tensor; must hold at least as many bytes as ``src``.
        src: source torch tensor; its full byte size is copied.

    Raises:
        ValueError: if ``dst`` is smaller than ``src`` (the raw copy would overrun ``dst``).
        RuntimeError: if ``cudaMemcpy`` reports an error.
    """
    nbytes = src.element_size() * src.nelement()
    dst_nbytes = dst.element_size() * dst.nelement()
    # The copy works on raw pointers, so an undersized destination would be silently overrun.
    if dst_nbytes < nbytes:
        raise ValueError(f"cuda_memcpy: destination holds {dst_nbytes} bytes but source has {nbytes}")

    from cuda import cudart

    # cuda-python runtime calls return a tuple whose first element is the cudaError_t status.
    result = cudart.cudaMemcpy(
        dst.data_ptr(),
        src.data_ptr(),
        nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
    )
    err = result[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaMemcpy failed: {err}")
29
+
30
+
31
+ class ORTGenerator:
32
+ def __init__(self, decoder_path):
33
+ self.onnx_decoder_path = decoder_path
34
+ self.num_heads = 32
35
+ self.head_size = 80
36
+ self.num_layers = 32
37
+ self.max_sequence_length = 2048
38
+ self.device_id = 0
39
+ self.use_cuda_graph = False
40
+ self.use_traced_inputs = False
41
+ self.static_inputs_map = {}
42
+
43
+ def append_static_inputs(self, batch_size):
44
+ # Only use this function with GQA and with use_cuda_graph=True
45
+ if batch_size in self.static_inputs_map:
46
+ return
47
+
48
+ cpu_device = torch.device("cpu")
49
+ cuda_device = torch.device("cuda", self.device_id)
50
+
51
+ static_io = {}
52
+ static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
53
+ static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
54
+ static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
55
+ static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
56
+
57
+ cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
58
+ for i in range(self.num_layers):
59
+ cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
60
+ static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
61
+
62
+ static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
63
+
64
+ self.static_inputs_map[batch_size] = static_io
65
+
66
+ def get_initial_inputs_and_outputs(self, encodings_dict):
67
+ self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
68
+
69
+ input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
70
+ attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
71
+
72
+ batch_size, sequence_length = input_ids.shape
73
+
74
+ self.use_traced_inputs = (
75
+ self.use_cuda_graph
76
+ and (batch_size in self.static_inputs_map)
77
+ and self.use_buffer_share
78
+ and not self.packed_kv
79
+ )
80
+
81
+ step = (
82
+ torch.tensor([0], device=self.device, dtype=torch.int64)
83
+ if not self.use_traced_inputs
84
+ else self.static_inputs_map[batch_size]["step"]
85
+ )
86
+
87
+ seqlens_k = (
88
+ torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
89
+ if not self.use_traced_inputs
90
+ else self.static_inputs_map[batch_size]["seqlens_k"]
91
+ )
92
+ cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
93
+
94
+ total_seq_length = (
95
+ torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
96
+ if not self.use_traced_inputs
97
+ else self.static_inputs_map[batch_size]["total_sequence_length"]
98
+ )
99
+ total_seq_length[0] = sequence_length
100
+
101
+ inputs = {
102
+ "input_ids": input_ids.contiguous(),
103
+ "attention_mask": attention_mask.contiguous(),
104
+ }
105
+
106
+ if self.use_step:
107
+ inputs["step"] = step.contiguous()
108
+
109
+ if self.use_cuda_graph:
110
+ inputs["seqlens_k"] = seqlens_k.contiguous()
111
+ inputs["total_sequence_length"] = total_seq_length.contiguous()
112
+ del inputs["attention_mask"]
113
+
114
+ past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
115
+ past_shape = (
116
+ (2, batch_size, self.num_heads, past_seq_length, self.head_size)
117
+ if self.packed_kv
118
+ else (batch_size, self.num_heads, past_seq_length, self.head_size)
119
+ )
120
+
121
+ if not self.use_traced_inputs:
122
+ for i in range(self.num_layers):
123
+ past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
124
+ (
125
+ inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
126
+ if not self.packed_kv
127
+ else inputs.update({f"past_{i}": past.contiguous()})
128
+ )
129
+ else:
130
+ for i in range(self.num_layers):
131
+ inputs.update(
132
+ {
133
+ f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
134
+ f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
135
+ }
136
+ )
137
+
138
+ logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
139
+ outputs = {"logits": logits.contiguous()}
140
+
141
+ if not self.use_buffer_share:
142
+ present_shape = (
143
+ (2, batch_size, self.num_heads, sequence_length, self.head_size)
144
+ if self.packed_kv
145
+ else (batch_size, self.num_heads, sequence_length, self.head_size)
146
+ )
147
+ for i in range(self.num_layers):
148
+ present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
149
+ (
150
+ outputs.update(
151
+ {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
152
+ )
153
+ if not self.packed_kv
154
+ else outputs.update({f"present_{i}": present.contiguous()})
155
+ )
156
+
157
+ return inputs, outputs
158
+
159
+ def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
160
+ io_binding = model.io_binding()
161
+ device = None
162
+
163
+ for k, v in inputs.items():
164
+ io_binding.bind_input(
165
+ name=k,
166
+ device_type=v.device.type,
167
+ device_id=0 if v.device.type == "cpu" else v.device.index,
168
+ element_type=pt_to_np[repr(v.dtype)],
169
+ shape=tuple(v.shape),
170
+ buffer_ptr=v.data_ptr(),
171
+ )
172
+ device = v.device
173
+
174
+ for output in model.get_outputs():
175
+ name = output.name
176
+ if self.use_buffer_share and "present" in name:
177
+ v = inputs[name.replace("present", "past")]
178
+ io_binding.bind_output(
179
+ name=name,
180
+ device_type=v.device.type,
181
+ device_id=v.device.index,
182
+ element_type=(np.float16 if self.use_fp16 else np.float32),
183
+ shape=tuple(v.shape),
184
+ buffer_ptr=v.data_ptr(),
185
+ )
186
+ else:
187
+ v = outputs[name]
188
+ io_binding.bind_output(
189
+ name=name,
190
+ device_type=device.type,
191
+ device_id=0 if device.type == "cpu" else device.index,
192
+ element_type=(np.float16 if self.use_fp16 else np.float32),
193
+ shape=tuple(v.shape),
194
+ buffer_ptr=v.data_ptr(),
195
+ )
196
+
197
+ return io_binding
198
+
199
+ def create_session(
200
+ self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
201
+ ):
202
+ self.device_id = device_id
203
+ sess_options = ort.SessionOptions()
204
+ sess_options.log_verbosity_level = 4
205
+ sess_options.log_severity_level = 4
206
+ self.use_cuda_graph = use_cuda_graph
207
+ ep = (
208
+ ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
209
+ if self.device_id >= 0
210
+ else "CPUExecutionProvider"
211
+ )
212
+ self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
213
+ self.ro = ort.RunOptions()
214
+
215
+ self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
216
+ self.use_fp16 = use_fp16
217
+ self.use_buffer_share = use_buffer_share
218
+ self.packed_kv = packed_kv
219
+ self.use_step = use_step
220
+
221
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
222
+ self.tokenizer.pad_token = "[PAD]"
223
+
224
+ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
225
+ inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
226
+
227
+ all_token_ids = inputs["input_ids"].clone()
228
+ batch_size, sequence_length = all_token_ids.shape
229
+
230
+ current_length = sequence_length
231
+ has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
232
+
233
+ if benchmark:
234
+ import time
235
+
236
+ latency = []
237
+
238
+ prompt_run = True
239
+ while current_length < max_length:
240
+ io_binding = self.apply_io_binding(self.sess, inputs, outputs)
241
+
242
+ if benchmark:
243
+ start = time.time()
244
+
245
+ io_binding.synchronize_inputs()
246
+ if prompt_run:
247
+ if self.use_cuda_graph:
248
+ # Disable CUDA graph for the prompt run
249
+ self.ro.add_run_config_entry("gpu_graph_id", "-1")
250
+ self.sess.run_with_iobinding(io_binding, self.ro)
251
+ if self.use_cuda_graph:
252
+ # Enable CUDA graph for the decoding run
253
+ self.ro.add_run_config_entry(
254
+ "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
255
+ )
256
+ prompt_run = False
257
+ else:
258
+ self.sess.run_with_iobinding(io_binding, self.ro)
259
+ io_binding.synchronize_outputs()
260
+
261
+ if benchmark:
262
+ end = time.time()
263
+ latency.append(end - start)
264
+
265
+ # Sample with argmax (greedy search)
266
+ next_token_logits = outputs["logits"][:, -1, :]
267
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
268
+
269
+ # Check if we previously reached EOS token id or if generated token id is EOS token id
270
+ has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id
271
+
272
+ # Determine which new tokens to add to list of all token ids
273
+ # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
274
+ tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
275
+ all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
276
+
277
+ # Return early if all batch entries have reached EOS token id
278
+ if torch.all(has_eos):
279
+ break
280
+
281
+ # Update inputs for next inference run
282
+ current_length += 1
283
+
284
+ inputs["input_ids"] = tokens_to_add.to(torch.int32)
285
+ if self.use_traced_inputs:
286
+ cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
287
+ inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
288
+
289
+ if self.use_step:
290
+ inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
291
+ if self.use_traced_inputs:
292
+ cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
293
+ inputs["step"] = self.static_inputs_map[batch_size]["step"]
294
+
295
+ if self.use_cuda_graph:
296
+ previous_seqlens_k = inputs["seqlens_k"]
297
+ inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
298
+ inputs["total_sequence_length"][0] = current_length
299
+ if self.use_traced_inputs:
300
+ cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
301
+ inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
302
+ self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
303
+ inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
304
+ else:
305
+ inputs["attention_mask"] = torch.cat(
306
+ [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
307
+ ).to(torch.int32)
308
+
309
+ # Set logits to zeros for next inference run and re-use memory buffer
310
+ if outputs["logits"].shape[1] != 1:
311
+ outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
312
+ if self.use_traced_inputs:
313
+ outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
314
+ outputs["logits"].zero_()
315
+
316
+ if not self.use_buffer_share:
317
+ for i in range(self.num_layers):
318
+ if not self.packed_kv:
319
+ inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
320
+ inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
321
+ else:
322
+ inputs[f"past_{i}"] = outputs[f"present_{i}"]
323
+
324
+ new_sequence_length = inputs["attention_mask"].shape[1]
325
+ present_shape = (
326
+ (2, batch_size, self.num_heads, new_sequence_length, self.head_size)
327
+ if self.packed_kv
328
+ else (batch_size, self.num_heads, new_sequence_length, self.head_size)
329
+ )
330
+ for i in range(self.num_layers):
331
+ present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
332
+ (
333
+ outputs.update(
334
+ {
335
+ f"present_key_{i}": present.contiguous(),
336
+ f"present_value_{i}": present.clone().contiguous(),
337
+ }
338
+ )
339
+ if not self.packed_kv
340
+ else outputs.update({f"present_{i}": present.contiguous()})
341
+ )
342
+
343
+ if benchmark:
344
+ print(
345
+ f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
346
+ )
347
+ print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
348
+ return
349
+
350
+ texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
351
+ return texts
352
+
353
def generate(self, prompt, max_length, cuda_graph_annotation):
    """Tokenize a batch of prompts and hand it to ``generate_impl``.

    Args:
        prompt: batch of prompt strings passed to the tokenizer's
            ``batch_encode_plus``; shorter prompts are padded
            (``padding=True``) so the batch is rectangular.
        max_length: forwarded unchanged to ``generate_impl``.
        cuda_graph_annotation: forwarded unchanged to ``generate_impl``.

    Returns:
        Whatever ``generate_impl`` returns for this encoded batch (on its
        non-benchmark path that is the decoded texts; the default of its
        ``benchmark`` flag is not visible here, so treat this as an assumption).
    """
    tokenizer = self.tokenizer
    encoded_batch = tokenizer.batch_encode_plus(prompt, padding=True)
    return self.generate_impl(encoded_batch, max_length, cuda_graph_annotation)
357
+
358
def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
    """Run generation twice on a synthetic random prompt batch.

    Builds a ``(batch_size, sequence_length)`` batch of random int32 token ids
    drawn from ``[0, 50264)`` with an all-ones attention mask (both converted
    to nested Python lists), then calls ``generate_impl`` twice with the same
    encodings: first with ``benchmark=False`` as a warm-up, then with
    ``benchmark=True``.

    Args:
        prompt_shape: ``(batch_size, sequence_length)`` of the synthetic prompt.
        token_num: number of tokens to generate past the prompt; the limit
            passed on is ``sequence_length + token_num``.
        cuda_graph_annotation: forwarded unchanged to ``generate_impl``.
    """
    batch_size, sequence_length = prompt_shape
    shape = (batch_size, sequence_length)

    random_ids = torch.randint(0, 50264, shape, dtype=torch.int32)
    full_mask = torch.ones(shape, dtype=torch.int32)
    encodings_dict = {
        "input_ids": random_ids.tolist(),
        "attention_mask": full_mask.tolist(),
    }

    max_length = sequence_length + token_num

    # First pass is the warm-up; only the second (benchmark=True) pass is the measured one.
    for benchmark in (False, True):
        self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=benchmark)
371
+
372
+
373
def run_phi2(
    onnx_model_path,
    use_buffer_share,
    device_id,
    packed_kv=False,
    use_fp16=True,
    use_step=False,
    use_cuda_graph=False,
    run_benchmark=False,
):
    """Create an ``ORTGenerator`` for ``onnx_model_path`` and either demo or benchmark it.

    If ``run_benchmark`` is false, generates a completion (total length up to
    210 tokens) for one hard-coded code prompt and prints prompt and result.
    Otherwise, times generation of 32 new tokens for every combination of
    batch size (1, 2, 4, 8) and prompt length (16, 512) on random prompts.

    Args:
        onnx_model_path: model path handed to ``ORTGenerator``.
        use_buffer_share, device_id, packed_kv, use_fp16, use_step,
        use_cuda_graph: forwarded to ``ORTGenerator.create_session``.
        run_benchmark: run the benchmark sweep instead of the demo prompt.
    """
    generator = ORTGenerator(onnx_model_path)
    generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)

    def simple_run(prompt):
        example_batch_size = len(prompt)
        # Static input buffers for this batch size are only set up when CUDA graphs are enabled.
        if use_cuda_graph:
            generator.append_static_inputs(batch_size=example_batch_size)
        texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)

        # One decoded text per prompt in the batch.
        for prompt_text, generated_text in zip(prompt, texts):
            print("Prompt: ", prompt_text)
            print("Texts: ", generated_text)

    prompt = [
        '''```python
def print_prime(n):
"""
Print all primes between 1 and n
"""'''
    ]

    if run_benchmark:
        # Run simple benchmark. Time the decoder only.
        token_num = 32
        for batch_size in [1, 2, 4, 8]:
            generator.append_static_inputs(batch_size)
            for sequence_length in [16, 512]:
                prompt_shape = (batch_size, sequence_length)
                generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
    else:
        simple_run(prompt)
@@ -0,0 +1,12 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import os.path
6
+ import sys
7
+
8
# Make modules importable by plain name when these scripts are used directly:
# this file's own directory and the directory two levels above it
# (``transformers_dir``). Each entry is appended only if it is not already on
# ``sys.path``, so re-executing this module does not accumulate duplicates.
_script_dir = os.path.dirname(__file__)
if _script_dir not in sys.path:
    sys.path.append(_script_dir)

transformers_dir = os.path.normpath(os.path.join(_script_dir, "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)